diff --git a/.bazelrc b/.bazelrc
index d7ae76f096431a..02dec0349c4741 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -51,16 +51,13 @@
 # Remote build execution options (only configured to work with TF team projects for now.)
 # rbe_base: General RBE options shared by all flavors.
 # rbe_linux: General RBE options used on all linux builds.
-# rbe_win: General RBE options used on all windows builds.
+# rbe_win_base: General RBE options used on all Windows builds. Not to be used standalone.
+# rbe_win_clang: Options specific to compiling using Clang.
 #
 # rbe_linux_cpu: RBE options to build with only CPU support.
 # rbe_linux_cuda: RBE options to build with GPU support using clang.
 # rbe_linux_cuda_nvcc: RBE options to build with GPU support using nvcc.
 #
-# rbe_win_py39: Windows Python 3.9 RBE config
-#
-# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
-#
 # Embedded Linux options (experimental and only tested with TFLite build yet)
 # elinux: General Embedded Linux options shared by all flavors.
 # elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
@@ -450,6 +447,17 @@ build:win_clang --host_linkopt=/FORCE:MULTIPLE
 test:win_clang --linkopt=/FORCE:MULTIPLE
 test:win_clang --host_linkopt=/FORCE:MULTIPLE
 
+# Same config as above but for XLA, which has different toolchain paths
+build:win_clang_xla --copt=/clang:-Weverything
+build:win_clang_xla --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
+build:win_clang_xla --extra_execution_platforms=//tools/toolchains/win:x64_windows-clang-cl
+build:win_clang_xla --host_platform=//tools/toolchains/win:x64_windows-clang-cl
+build:win_clang_xla --compiler=clang-cl
+build:win_clang_xla --linkopt=/FORCE:MULTIPLE
+build:win_clang_xla --host_linkopt=/FORCE:MULTIPLE
+test:win_clang_xla --linkopt=/FORCE:MULTIPLE
+test:win_clang_xla --host_linkopt=/FORCE:MULTIPLE
+
 # Options to build TensorFlow 1.x or 2.x.
 # TODO(kanglan): Change v2's define to default behavior
 build:v2 --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1
@@ -546,38 +554,25 @@ build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda
 build:rbe_linux_cuda_nvcc --config=nvcc_clang
 build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1
 
-# TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed
-build:rbe_win --config=rbe_base
-build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_05022023:toolchain"
-build:rbe_win --extra_toolchains="//tensorflow/tools/toolchains/win/tf_win_05022023:cc-toolchain-x64_windows"
-build:rbe_win --extra_execution_platforms="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019"
-build:rbe_win --host_platform="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019"
-build:rbe_win --platforms="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019"
-build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
-build:rbe_win --experimental_strict_action_env=true
-
-# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
-build:rbe_win --define=override_eigen_strong_inline=true
-
+build:rbe_win_base --config=rbe_base
+build:rbe_win_base --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
+build:rbe_win_base --remote_instance_name=projects/tensorflow-testing/instances/windows
 # Don't build the python zip archive in the RBE build.
-build:rbe_win --remote_download_minimal
-build:rbe_win --enable_runfiles
-build:rbe_win --nobuild_python_zip
-
-build:rbe_win_py38 --config=rbe_base
-build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
-build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
-build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=//tensorflow/tools/toolchains/win_1803/py38
-build:rbe_win_py38 --python_path=C:\\Python38\\python.exe
-
-build:rbe_win_py39 --config=rbe_base
-build:rbe_win_py39 --repo_env=PYTHON_BIN_PATH=C:\\Python39\\python.exe
-build:rbe_win_py39 --repo_env=PYTHON_LIB_PATH=C:\\Python39\\lib\\site-packages
-build:rbe_win_py39 --repo_env=TF_PYTHON_CONFIG_REPO=//tensorflow/tools/toolchains/win_1803/py39
-build:rbe_win_py39 --python_path=C:\\Python39\\python.exe
-
-# TODO(kanglan): Merge tensorflow_testing_rbe_win into rbe_win
-common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
+build:rbe_win_base --remote_download_minimal
+build:rbe_win_base --enable_runfiles
+build:rbe_win_base --nobuild_python_zip
+build:rbe_win_base --define=override_eigen_strong_inline=true
+
+build:rbe_win_clang --config=rbe_win_base
+build:rbe_win_clang --crosstool_top="//tensorflow/tools/toolchains/win/20240424:toolchain"
+build:rbe_win_clang --extra_toolchains="//tensorflow/tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
+build:rbe_win_clang --extra_execution_platforms="//tensorflow/tools/toolchains/win:x64_windows-clang-cl"
+build:rbe_win_clang --host_platform="//tensorflow/tools/toolchains/win:x64_windows-clang-cl"
+build:rbe_win_clang --platforms="//tensorflow/tools/toolchains/win:x64_windows-clang-cl"
+build:rbe_win_clang --compiler=clang-cl
+build:rbe_win_clang --linkopt=/FORCE:MULTIPLE
+build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE
+
 # END TF REMOTE BUILD EXECUTION OPTIONS
 
 # TFLite build configs for generic embedded Linux
@@ -815,7 +810,7 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo
 build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
 build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
 build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3
-# TODO(michaelhudgins): Why do we need to specifically omit go and java here? 
+# TODO(michaelhudgins): Why do we need to specifically omit go and java here?
 build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test
 
 # CROSS-COMPILE ARM64 PYCPP
 build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test
@@ -924,7 +919,9 @@ build:cross_compile_macos_x86 --extra_toolchains=//tensorflow/tools/toolchains/c
 build:cross_compile_macos_x86 --platform_mappings=tensorflow/tools/toolchains/cross_compile/config/platform_mappings
 
 # RBE cross-compile configs for Darwin x86
-build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86
+build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 --remote_download_minimal
+build:rbe_cross_compile_macos_x86 --bes_backend="" --bes_results_url="" --bes_timeout="0s"
+build:rbe_cross_compile_macos_x86 --experimental_remote_build_event_upload="minimal"
 build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base
 build:rbe_cross_compile_macos_x86 --bes_upload_mode=nowait_for_upload_complete
 test:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base
diff --git a/RELEASE.md b/RELEASE.md
index 3c6198b60d1918..8287988507e571 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -31,6 +31,7 @@
   been added to TF binary distributions (Python wheels).
 * Replace `DebuggerOptions` of TensorFlow Quantizer, and migrate to
   `DebuggerConfig` of StableHLO Quantizer.
+* Add TensorFlow to StableHLO converter to TensorFlow pip package.
 
 ## Keras
 
@@ -87,6 +88,8 @@
   * The Python TF Lite Interpreter bindings now have an option
     `experimental_default_delegate_latest_features` to enable all default
     delegate features.
+  * Flatbuffer version update:
+    * `GetTemporaryPointer()` bug fixed.
 
 * `tf.data`
   * Add `wait` to `tf.data.Dataset.load`. If `True`, for snapshots written
@@ -95,6 +98,13 @@
     it's finished. The default is `False` for backward compatibility. Users of
     `distributed_save` are recommended to set it to `True`.
 
+* `tf.tpu.experimental.embedding.TPUEmbeddingV2`
+  * Add `compute_sparse_core_stats` for sparse core users to profile the
+    data with this API to get the `max_ids` and `max_unique_ids`. These
+    numbers will be needed to configure the sparse core embedding mid level
+    api.
+  * Remove the `preprocess_features` method since that's no longer needed.
+
 ## Thanks to our Contributors
 
 This release contains contributions from many people at Google, as well as:
diff --git a/ci/official/containers/linux_arm64/build.sh b/ci/official/containers/linux_arm64/build.sh
index 5d6a40658bd782..611d5f48ac0084 100755
--- a/ci/official/containers/linux_arm64/build.sh
+++ b/ci/official/containers/linux_arm64/build.sh
@@ -40,11 +40,15 @@ else
   fi
 fi
 
+# TODO(b/341050361): When these steps are verified, remove the GCR image code.
+AR_IMAGE_PATH="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64"
+
 # Build for both JAX and TF usage. We do these in one place because they share
 # almost all of the same cache layers
 export DOCKER_BUILDKIT=1
 for target in jax tf; do
   IMAGE="gcr.io/tensorflow-sigs/build-arm64:$target-$TAG"
+  AR_IMAGE="$AR_IMAGE_PATH:$target-$TAG"
   docker pull "$IMAGE" || true
   # Due to some flakiness of resources pulled in the build, allow the docker
   # command to reattempt build a few times in the case of failure (b/302558736)
@@ -55,7 +59,7 @@ for target in jax tf; do
       --build-arg REQUIREMENTS_FILE=jax.requirements.txt \
       --target=$target \
       --cache-from "$IMAGE" \
-      -t "$IMAGE" . && break
+      -t "$IMAGE" -t "$AR_IMAGE" . && break
   done
   final=$?
   if [ $final -ne 0 ]; then
@@ -66,5 +70,7 @@ for target in jax tf; do
   if [[ -n "$KOKORO_BUILD_ID" ]]; then
     gcloud auth configure-docker
     docker push "$IMAGE"
+    gcloud auth configure-docker us-central1-docker.pkg.dev
+    docker push "$AR_IMAGE"
   fi
 done
diff --git a/ci/official/utilities/setup_docker.sh b/ci/official/utilities/setup_docker.sh
index 36afa2545eb244..91618c75f3ba51 100755
--- a/ci/official/utilities/setup_docker.sh
+++ b/ci/official/utilities/setup_docker.sh
@@ -14,11 +14,12 @@
 # limitations under the License.
 # ==============================================================================
 if [[ "$TFCI_DOCKER_PULL_ENABLE" == 1 ]]; then
-  # Simple retry logic for docker-pull errors. Sleeps for 15s if a pull fails.
+  # Simple retry logic for docker-pull errors. Sleeps if a pull fails.
   # Pulling an already-pulled container image will finish instantly, so
   # repeating the command costs nothing.
   docker pull "$TFCI_DOCKER_IMAGE" || sleep 15
-  docker pull "$TFCI_DOCKER_IMAGE" || sleep 15
+  docker pull "$TFCI_DOCKER_IMAGE" || sleep 30
+  docker pull "$TFCI_DOCKER_IMAGE" || sleep 60
   docker pull "$TFCI_DOCKER_IMAGE"
 fi
diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt
index 05dc3940487eef..f17468ddaafd0a 100644
--- a/requirements_lock_3_10.txt
+++ b/requirements_lock_3_10.txt
@@ -522,9 +522,9 @@ urllib3==2.2.0 \
     --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \
     --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224
     # via requests
-werkzeug==3.0.1 \
-    --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \
-    --hash=sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10
+werkzeug==3.0.3 \
+    --hash=sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18 \
+    --hash=sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8
     # via tb-nightly
 wheel==0.41.3 \
     --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \
diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt
index 05dc3940487eef..f17468ddaafd0a 100644
--- a/requirements_lock_3_11.txt
+++ b/requirements_lock_3_11.txt
@@ -522,9 +522,9 @@ urllib3==2.2.0 \
     --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \
     --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224
     # via requests
-werkzeug==3.0.1 \
-    --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \
-    --hash=sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10
+werkzeug==3.0.3 \
+    --hash=sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18 \
+    --hash=sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8
     # via tb-nightly
 wheel==0.41.3 \
     --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \
diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt
index 120ec6ebcd7c72..0d045ea1a0579c 100644
--- a/requirements_lock_3_12.txt
+++ b/requirements_lock_3_12.txt
@@ -530,9 +530,9 @@ urllib3==2.2.0 \
     --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \
     --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224
     # via requests
-werkzeug==3.0.1 \
-    --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \
-    --hash=sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10
+werkzeug==3.0.3 \
+    --hash=sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18 \
+    --hash=sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8
     # via tb-nightly
 wheel==0.41.3 \
     --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \
diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt
index 36a55514cd788b..48c74173fe553f 100644
--- a/requirements_lock_3_9.txt
+++ b/requirements_lock_3_9.txt
@@ -526,9 +526,9 @@ urllib3==2.2.0 \
     --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \
     --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224
     # via requests
-werkzeug==3.0.1 \
-    --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \
-    --hash=sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10
+werkzeug==3.0.3 \
+    --hash=sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18 \
+    --hash=sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8
     # via tb-nightly
 wheel==0.41.3 \
     --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 71487e2aec0bee..a4cd4af8975bc2 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -1382,6 +1382,7 @@ tf_cc_shared_library(
         "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config",
         "//tensorflow/compiler/mlir/lite/sparsity:sparsify_model",
         "//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization_lib_impl",
+        "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl",
         "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op",
         "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl",
         "//tensorflow/compiler/mlir/quantization/tensorflow:passes",
@@ -1416,6 +1417,7 @@ tf_cc_shared_library(
         "//tensorflow/core/grappler:grappler_item_builder",
         "//tensorflow/core/kernels:data_service_ops",
         "//tensorflow/core/kernels:dataset_ops",
+        "//tensorflow/core/tpu/kernels:sparse_core_layout",
         "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform:path",
        "//tensorflow/core/platform:stacktrace_handler",
diff --git a/tensorflow/c/c_test.c b/tensorflow/c/c_test.c
index ce8a115c5b21bd..5415b2deaf6c93 100644
--- a/tensorflow/c/c_test.c
+++ b/tensorflow/c/c_test.c
@@ -20,6 +20,10 @@ limitations under the License.
 #include
 #include
 
+#ifdef _WIN32
+#include
+#endif
+
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_experimental.h"
 #include "tensorflow/c/env.h"
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index a433b618de7142..d20be8abcf02a4 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -296,8 +296,8 @@ TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
 
 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
   if (h == nullptr) return;
-  tensorflow::profiler::TraceMe activity(
-      "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
+  tsl::profiler::TraceMe activity("TFE_DeleteTensorHandle",
+                                  tsl::profiler::TraceMeLevel::kInfo);
   if (h) {
     tensorflow::unwrap(h)->Unref();
   }
diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc
index 8422459c21b529..ab29b1cd6ff051 100644
--- a/tensorflow/c/eager/c_api_unified_experimental.cc
+++ b/tensorflow/c/eager/c_api_unified_experimental.cc
@@ -216,7 +216,7 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
   Status status =
       unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value));
   TF_SetStatus(s, static_cast<TF_Code>(status.code()),
-               tsl::NullTerminatedMessage(status));
+               absl::StatusMessageAsCStr(status));
 }
 
 void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
diff --git a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD
index 159e36e485e6a6..8cb30fa9ae0828 100644
--- a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD
+++ b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD
@@ -31,8 +31,5 @@ cc_library(
         "nobuilder",
         "notap",
     ],
-    deps = [
-        "//tensorflow/c:tf_status",
-        "//tensorflow/c/experimental/filesystem:filesystem_interface",
-    ],
+    deps = ["//tensorflow/c/experimental/filesystem:filesystem_interface"],
 )
diff --git a/tensorflow/c/experimental/filesystem/plugins/windows/windows_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/windows/windows_filesystem.cc
index 31400562a2579e..e382b829341411 100644
--- a/tensorflow/c/experimental/filesystem/plugins/windows/windows_filesystem.cc
+++ b/tensorflow/c/experimental/filesystem/plugins/windows/windows_filesystem.cc
@@ -16,7 +16,6 @@ limitations under the License.
 #include
 
 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
-#include "tensorflow/c/tf_status.h"
 
 // Implementation of a filesystem for POSIX environments.
 // This filesystem will support `file://` and empty (local) URI schemes.
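An aside on the docker-pull change in ci/official/utilities/setup_docker.sh above: the retries now sleep 15s, 30s, then 60s instead of a fixed 15s between attempts. A minimal standalone sketch of that escalating-backoff pattern, assuming only that TFCI_DOCKER_IMAGE may be set in the environment (the fallback image name below is a hypothetical placeholder, not part of this patch):

    #!/bin/bash
    # Retry a flaky "docker pull" with escalating sleeps. Re-pulling an
    # already-pulled image finishes instantly, so repeating it is cheap.
    IMAGE="${TFCI_DOCKER_IMAGE:-ubuntu:22.04}"  # hypothetical default image
    for delay in 15 30 60; do
      docker pull "$IMAGE" && break  # stop retrying once a pull succeeds
      sleep "$delay"                 # back off longer after each failure
    done
    docker pull "$IMAGE"             # final attempt; failure propagates

The escalation bounds the total wait while still giving transient registry errors time to clear.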
diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.cc b/tensorflow/c/experimental/next_pluggable_device/c_api.cc
index 15a50a0a7c4060..a4d47753bdd5d8 100644
--- a/tensorflow/c/experimental/next_pluggable_device/c_api.cc
+++ b/tensorflow/c/experimental/next_pluggable_device/c_api.cc
@@ -91,7 +91,7 @@ void TF_LookupOrCreatePluginResource(
             void* opaque_plugin_resource = create_func(create_func_args);
             *new_resource = new tensorflow::PluginResource(
                 opaque_plugin_resource, plugin_resource_name, delete_func);
-            return tensorflow::OkStatus();
+            return absl::OkStatus();
           });
 
   if (cc_status.ok()) {
diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc
index 18a851e394aea7..dd15c9f078cd1d 100644
--- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc
+++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc
@@ -69,7 +69,7 @@ absl::Status SetPjRtCBufferToTensor(PJRT_Buffer* c_buffer,
 
 absl::StatusOr<xla::PjRtCApiClient*> GetPjRtCApiClient(
     const DeviceType& device_type) {
-  TF_ASSIGN_OR_RETURN(tsl::StatusOr<xla::PjRtClient*> pjrt_client,
+  TF_ASSIGN_OR_RETURN(absl::StatusOr<xla::PjRtClient*> pjrt_client,
                       tensorflow::GetPjRtClient(device_type));
   auto* pjrt_c_api_client = dynamic_cast<xla::PjRtCApiClient*>(*pjrt_client);
   if (pjrt_c_api_client == nullptr) {
diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h
index 85c60120b07241..c2b1051f75c39e 100644
--- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h
+++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h
@@ -21,7 +21,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-StatusOr<PJRT_Buffer*> GetPjRtCBufferFromTensor(const Tensor* tensor);
+absl::StatusOr<PJRT_Buffer*> GetPjRtCBufferFromTensor(const Tensor* tensor);
 
 absl::Status SetPjRtCBufferToTensor(PJRT_Buffer* c_buffer,
                                     xla::PjRtCApiClient* c_api_client,
diff --git a/tensorflow/c/experimental/ops/gen/common/case_format.cc b/tensorflow/c/experimental/ops/gen/common/case_format.cc
index 9b8e955356db07..1e9d123005e8a4 100644
--- a/tensorflow/c/experimental/ops/gen/common/case_format.cc
+++ b/tensorflow/c/experimental/ops/gen/common/case_format.cc
@@ -14,6 +14,9 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/c/experimental/ops/gen/common/case_format.h"
 
+#include "tensorflow/core/platform/str_util.h"
+#include "tensorflow/core/platform/types.h"
+
 namespace tensorflow {
 namespace generator {
 
diff --git a/tensorflow/c/experimental/ops/gen/common/case_format_test.cc b/tensorflow/c/experimental/ops/gen/common/case_format_test.cc
index 37bc5be753fd64..302bcc42453169 100644
--- a/tensorflow/c/experimental/ops/gen/common/case_format_test.cc
+++ b/tensorflow/c/experimental/ops/gen/common/case_format_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/c/experimental/ops/gen/common/case_format.h"
 
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace generator {
diff --git a/tensorflow/c/experimental/ops/gen/common/controller.cc b/tensorflow/c/experimental/ops/gen/common/controller.cc
index a8e02f41011d32..cafb57c0919403 100644
--- a/tensorflow/c/experimental/ops/gen/common/controller.cc
+++ b/tensorflow/c/experimental/ops/gen/common/controller.cc
@@ -15,11 +15,17 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/gen/common/controller.h" #include "absl/strings/substitute.h" +#include "tensorflow/c/experimental/ops/gen/common/path_config.h" +#include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include "tensorflow/c/experimental/ops/gen/model/op_spec.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" +#include "tsl/platform/status.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/common/path_config.cc b/tensorflow/c/experimental/ops/gen/common/path_config.cc index d9e3881bf15580..b8f84d5f31f4d3 100644 --- a/tensorflow/c/experimental/ops/gen/common/path_config.cc +++ b/tensorflow/c/experimental/ops/gen/common/path_config.cc @@ -16,7 +16,9 @@ limitations under the License. #include +#include "absl/strings/str_join.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/common/source_code.cc b/tensorflow/c/experimental/ops/gen/common/source_code.cc index ea4db53d167109..ea2b66fac7cd27 100644 --- a/tensorflow/c/experimental/ops/gen/common/source_code.cc +++ b/tensorflow/c/experimental/ops/gen/common/source_code.cc @@ -14,9 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include "absl/strings/ascii.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stringpiece.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/common/view_util.cc b/tensorflow/c/experimental/ops/gen/common/view_util.cc index a14c7e38b63b46..7c8717067b08fe 100644 --- a/tensorflow/c/experimental/ops/gen/common/view_util.cc +++ b/tensorflow/c/experimental/ops/gen/common/view_util.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/view_util.h" +#include "absl/strings/str_join.h" #include "absl/strings/substitute.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc index 7ca7a7bf639bf7..509f209ffd7b42 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc @@ -44,7 +44,7 @@ Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx, // Note that if shape is unknown rank, shape.dim_sizes() will be empty, and // shape.dims() will be -1. 
-  gtl::InlinedVector dim_sizes = shape.dim_sizes();
+  absl::InlinedVector dim_sizes = shape.dim_sizes();
   TF_RETURN_IF_ERROR(varhandle_op->SetAttrShape(
       "shape", reinterpret_cast<int64_t*>(dim_sizes.data()), shape.dims()));
diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc
index 0701e3b9aa9fff..bc0fae5fd9aeb9 100644
--- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc
+++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc
@@ -506,12 +506,11 @@ TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
       tensorflow::down_cast(
           tensorflow::unwrap(saved_model));
   tensorflow::Variable* uninitialized_variable;
-  ASSERT_EQ(::tensorflow::OkStatus(),
-            model_api->GetVariable("uninitialized_variable",
-                                   &uninitialized_variable));
+  ASSERT_EQ(absl::OkStatus(), model_api->GetVariable("uninitialized_variable",
+                                                     &uninitialized_variable));
   ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype());
 
-  ASSERT_EQ(::tensorflow::OkStatus(),
+  ASSERT_EQ(absl::OkStatus(),
             model_api->GetVariable("sub_module.uninitialized_variable",
                                    &uninitialized_variable));
   ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype());
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc
index 93d07b431ee4cf..65b31f8cfb8f1c 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc
@@ -200,14 +200,16 @@ void HostCallbackTrampoline(void* ctx, TF_Status* status) {
   delete host_ctx;
 }
 
-class CStreamExecutor : public StreamExecutorInterface {
+class CStreamExecutor : public StreamExecutor {
  public:
-  explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns,
+  explicit CStreamExecutor(Platform* se_platform, SP_Device device,
+                           SP_DeviceFns* device_fns,
                            SP_StreamExecutor* stream_executor,
                            SP_Platform* platform, SP_PlatformFns* platform_fns,
                            SP_TimerFns* timer_fns, const std::string& name,
                            int visible_device_count)
-      : device_(std::move(device)),
+      : StreamExecutor(se_platform),
+        device_(std::move(device)),
         device_fns_(device_fns),
         stream_executor_(stream_executor),
         platform_(platform),
@@ -405,10 +407,6 @@ class CStreamExecutor : public StreamExecutorInterface {
     return stream_executor_->host_callback(&device_, stream_handle,
                                            &HostCallbackTrampoline, ctx);
   }
-  absl::Status AllocateEvent(Event* event) override {
-    DCHECK(event != nullptr);
-    return static_cast<CEvent*>(event->implementation())->Create();
-  }
   absl::Status DeallocateEvent(Event* event) override {
     static_cast<CEvent*>(event->implementation())->Destroy();
     return absl::OkStatus();
   }
@@ -436,14 +434,6 @@ class CStreamExecutor : public StreamExecutorInterface {
         stream_executor_->get_event_status(&device_, event_handle);
     return SEEventStatusToEventStatus(event_status);
   }
-  bool AllocateStream(Stream* stream) override {
-    DCHECK(stream != nullptr);
-    absl::Status status =
-        static_cast<CStream*>(stream->implementation())->Create();
-    // TODO(annarev): update AllocateStream to return status instead
-    // (similar to AllocateEvent).
-    return status.ok();
-  }
   void DeallocateStream(Stream* stream) override {
     static_cast<CStream*>(stream->implementation())->Destroy();
   }
@@ -557,15 +547,19 @@ class CStreamExecutor : public StreamExecutorInterface {
     return builder.Build();
   }
 
-  // Each call creates a new instance of the platform-specific implementation of
-  // the corresponding interface type.
-  std::unique_ptr<internal::EventInterface> CreateEventImplementation() override {
-    return std::unique_ptr<internal::EventInterface>(
-        new CEvent(&device_, stream_executor_));
+  absl::StatusOr<std::unique_ptr<Event>> CreateEvent() override {
+    auto c_event = std::make_unique<CEvent>(&device_, stream_executor_);
+    TF_RETURN_IF_ERROR(c_event->Create());
+    return std::make_unique<Event>(this, std::move(c_event));
   }
-  std::unique_ptr<internal::StreamInterface> GetStreamImplementation() override {
-    return std::unique_ptr<internal::StreamInterface>(
-        new CStream(&device_, stream_executor_));
+
+  absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
+      std::optional<std::variant<StreamPriority, int>> priority =
+          std::nullopt) override {
+    auto c_stream = std::make_unique<CStream>(&device_, stream_executor_);
+    TF_RETURN_IF_ERROR(c_stream->Create());
+    auto stream = std::make_unique<Stream>(this, std::move(c_stream));
+    return std::move(stream);
   }
 
  private:
@@ -644,11 +638,9 @@ absl::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
       c_status.get());
   TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
 
-  auto executor = std::make_unique<CStreamExecutor>(
-      std::move(device), &device_fns_, &stream_executor_, &platform_,
+  return std::make_unique<CStreamExecutor>(
+      this, std::move(device), &device_fns_, &stream_executor_, &platform_,
       &platform_fns_, &timer_fns_, name_, visible_device_count);
-  auto result = std::make_unique<StreamExecutor>(this, std::move(executor));
-  return result;
 }
 
 absl::Status InitStreamExecutorPlugin(void* dso_handle,
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
index 56f25a5811293e..680a1d9d1db1f5 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
@@ -342,11 +342,10 @@ TEST_F(StreamExecutorTest, CreateEvent) {
   StreamExecutor* executor = GetExecutor(0);
 
   ASSERT_FALSE(event_created);
-  Event* event = new Event(executor);
-  event->Init();
+  TF_ASSERT_OK_AND_ASSIGN(auto event, executor->CreateEvent());
   ASSERT_TRUE(event_created);
   ASSERT_FALSE(event_deleted);
-  delete event;
+  event.reset();
   ASSERT_TRUE(event_deleted);
 }
 
@@ -365,11 +364,10 @@ TEST_F(StreamExecutorTest, PollForEventStatus) {
   };
 
   StreamExecutor* executor = GetExecutor(0);
-  Event event(executor);
-  event.Init();
-  ASSERT_EQ(event.PollForStatus(), Event::Status::kComplete);
+  TF_ASSERT_OK_AND_ASSIGN(auto event, executor->CreateEvent());
+  ASSERT_EQ(event->PollForStatus(), Event::Status::kComplete);
 
   event_status = SE_EVENT_ERROR;
-  ASSERT_EQ(event.PollForStatus(), Event::Status::kError);
+  ASSERT_EQ(event->PollForStatus(), Event::Status::kError);
 }
 
 TEST_F(StreamExecutorTest, RecordAndWaitForEvent) {
@@ -403,14 +401,13 @@ TEST_F(StreamExecutorTest, RecordAndWaitForEvent) {
   };
 
   StreamExecutor* executor = GetExecutor(0);
-  Event event(executor);
-  event.Init();
+  TF_ASSERT_OK_AND_ASSIGN(auto event, executor->CreateEvent());
   TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
 
   ASSERT_FALSE(record_called);
-  TF_ASSERT_OK(stream->RecordEvent(&event));
+  TF_ASSERT_OK(stream->RecordEvent(event.get()));
   ASSERT_TRUE(record_called);
 
   ASSERT_FALSE(wait_called);
-  TF_ASSERT_OK(stream->WaitFor(&event));
+  TF_ASSERT_OK(stream->WaitFor(event.get()));
   ASSERT_TRUE(wait_called);
 }
 
diff --git a/tensorflow/c/kernels_experimental.cc b/tensorflow/c/kernels_experimental.cc
index 26173507f29aec..02e41428c6ae58 100644
--- a/tensorflow/c/kernels_experimental.cc
+++ b/tensorflow/c/kernels_experimental.cc
@@ -266,7 +266,7 @@ void TF_AssignUpdateVariable(TF_OpKernelContext* ctx, int input_index,
   Status status = LookupResource(context, HandleFromInput(context, input_index),
                                  &variable);
   if (!status.ok()) {
-    printf("Failed with error: %s\n", tsl::NullTerminatedMessage(status));
+    printf("Failed with error: %s\n", absl::StatusMessageAsCStr(status));
     abort();
   }
   const Tensor& value = context->input(value_index);
diff --git a/tensorflow/c/tf_status_helper.cc b/tensorflow/c/tf_status_helper.cc
index bbeae6f76bc497..c96a5af7440dff 100644
--- a/tensorflow/c/tf_status_helper.cc
+++ b/tensorflow/c/tf_status_helper.cc
@@ -25,7 +25,7 @@ namespace tsl {
 void Set_TF_Status_from_Status(TF_Status* tf_status,
                                const absl::Status& status) {
   TF_SetStatus(tf_status, TSLCodeFromStatusCode(status.code()),
-               tsl::NullTerminatedMessage(status));
+               absl::StatusMessageAsCStr(status));
 
   status.ForEachPayload(
       [tf_status](absl::string_view key, const absl::Cord& value) {
         std::string key_str(key);
diff --git a/tensorflow/cc/experimental/libtf/impl/BUILD b/tensorflow/cc/experimental/libtf/impl/BUILD
index 4f5b7ccfd84940..97b06b21682daa 100644
--- a/tensorflow/cc/experimental/libtf/impl/BUILD
+++ b/tensorflow/cc/experimental/libtf/impl/BUILD
@@ -39,6 +39,8 @@ tf_cc_test(
         ":scalars",
         ":string",
        ":tensor_spec",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
@@ -123,6 +125,8 @@ tf_cc_test(
    deps = [
        ":iostream",  # Necessary for absl::VerifyTypeImplementsAbslHashCorrectly.
        ":tensor_spec",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@com_google_absl//absl/hash:hash_testing",
diff --git a/tensorflow/cc/experimental/libtf/impl/iostream_test.cc b/tensorflow/cc/experimental/libtf/impl/iostream_test.cc
index 40c3d7550d00d4..dede1483d76187 100644
--- a/tensorflow/cc/experimental/libtf/impl/iostream_test.cc
+++ b/tensorflow/cc/experimental/libtf/impl/iostream_test.cc
@@ -16,6 +16,8 @@ limitations under the License.
 #include "tensorflow/cc/experimental/libtf/impl/scalars.h"
 #include "tensorflow/cc/experimental/libtf/impl/string.h"
 #include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tf {
diff --git a/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc b/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc
index e0654bec85fb29..dc07f77c7ba9b7 100644
--- a/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc
+++ b/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h" #include "absl/hash/hash_testing.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/test.h" namespace tf { diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 6cc731e722d16b..da27e61d380081 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -544,9 +544,7 @@ cc_library( name = "fingerprinting_utils", srcs = ["fingerprinting_utils.cc"], hdrs = ["fingerprinting_utils.h"], - visibility = [ - "//tensorflow:__pkg__", - ], + visibility = ["//visibility:private"], deps = [ ":constants", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 93d7527bdd409f..4666ddd5db9ed6 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -668,7 +668,6 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { set_static_data_program_shape(data, StaticProgramShape()); set_static_data_hlo_profile_printer_data( data, StaticHloProfilePrinterData()); - set_static_data_use_xla_runtime(data, {{USE_XLA_RUNTIME}}); {{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); @@ -822,7 +821,6 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { {"{{DECLS_FROM_OBJ_FILE}}", absl::StrJoin(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{USE_XLA_RUNTIME}}", opts.use_xla_runtime ? "true" : "false"}, {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", metadata_result.hlo_profile_printer_data_access_shim}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 71b234a8385806..cd1a72308c3ede 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -97,7 +97,6 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { set_static_data_program_shape(data, StaticProgramShape()); set_static_data_hlo_profile_printer_data( data, StaticHloProfilePrinterData()); - set_static_data_use_xla_runtime(data, false); return data; }(); diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index a543aae5b92997..99c8541c55488c 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -319,8 +319,6 @@ def _tf_library( ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually # needed. 
- "@local_xla//xla/service/cpu/runtime:convolution_ffi", - "@local_xla//xla/service/cpu/runtime:rng_ffi", "@local_xla//xla/service/cpu:runtime_conv2d", "@local_xla//xla/service/cpu:runtime_custom_call_status", "@local_xla//xla/service/cpu:runtime_key_value_sort", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 76f3c147903748..623334534567de 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -199,6 +199,7 @@ cc_library( "//tensorflow/core/tpu:tpu_node_device_util", "//tensorflow/core/tpu:virtual_device", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:status_helper", "@local_xla//xla/stream_executor/tpu:tpu_api", @@ -314,6 +315,7 @@ cc_library( "//tensorflow/core/common_runtime:dma_helper", "//tensorflow/core/framework:allocator", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla:util", "@local_xla//xla/client:global_data", "@local_xla//xla/client:local_client", @@ -1149,6 +1151,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_xla//xla:status_macros", diff --git a/tensorflow/compiler/jit/device_util.h b/tensorflow/compiler/jit/device_util.h index df3b7d04fbfe7b..ec4d9484ae8854 100644 --- a/tensorflow/compiler/jit/device_util.h +++ b/tensorflow/compiler/jit/device_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/numeric/bits.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -79,7 +80,7 @@ class DeviceSet { uint64 only_lowest_bit_set = word & -word; // The number of trailing zeros in a non-zero word is the index of the // least significant 1. - int bit_index = ctz_uint64(word); + int bit_index = absl::countr_zero(word); if (!func(DeviceId(word_index * kWordSize + bit_index))) { return; } @@ -89,20 +90,6 @@ class DeviceSet { } private: - static int ctz_uint64(uint64 x) { - DCHECK_NE(x, 0); -#ifdef __GNUC__ - return __builtin_ctzl(x); -#else - int result = 0u; - while ((x & 1u) == 0u) { - x >>= 1; - ++result; - } - return result; -#endif - } - absl::InlinedVector storage_; const int kWordSize = 64; diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 0ac326c61fb3ec..d173564b7fd10d 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -59,8 +59,10 @@ cc_library( "//tensorflow/compiler/jit:xla_compile_util", "//tensorflow/core/platform:refcount", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_xla//xla/pjrt:pjrt_client", + "@local_xla//xla/tsl/concurrency:async_value", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 9d75388cfbbe80..5a29e8ef36e9b3 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -24,10 +24,10 @@ limitations under the License. 
 #include
 #include
 #include
-#include
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/jit/device_compilation_profiler.h"
@@ -52,7 +52,7 @@ limitations under the License.
 #include "xla/executable_run_options.h"
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/service/gpu/gpu_executable_run_options.h"
-#include "xla/statusor.h"
+#include "xla/tsl/concurrency/async_value_ref.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -224,7 +224,7 @@ xla::SendDeviceMemoryFunction GetSendDeviceMemoryFunction(
           int64_t channel_id, se::Stream* stream, const xla::Shape& shape,
           const se::DeviceMemoryBase& device_memory_base,
           const absl::flat_hash_map<std::string, std::string>& frontend_attrs)
-      -> absl::StatusOr<tsl::AsyncValueRef<se::Event>> {
+      -> absl::StatusOr<tsl::AsyncValueRef<std::unique_ptr<se::Event>>> {
     auto iter = frontend_attrs.find("_xla_host_transfer_rendezvous");
 
     // Generate the Rendezvous key.
@@ -244,12 +244,10 @@ xla::SendDeviceMemoryFunction GetSendDeviceMemoryFunction(
     RendezvousInterface::ParsedKey parsed_key;
     TF_RETURN_IF_ERROR(Rendezvous::ParseKey(rendezvous_key, &parsed_key));
 
-    tsl::AsyncValueRef<se::Event> done_event =
-        tsl::MakeConstructedAsyncValueRef<se::Event>(stream->parent());
-    if (!done_event->Init()) {
-      return errors::Internal(
-          "Failed to initialize done event (channel_id=%d)", channel_id);
-    }
+    TF_ASSIGN_OR_RETURN(auto event, stream->parent()->CreateEvent());
+    tsl::AsyncValueRef<std::unique_ptr<se::Event>> done_event =
+        tsl::MakeConstructedAsyncValueRef<std::unique_ptr<se::Event>>(
+            std::move(event));
 
     Rendezvous::Args args;
     // Rendezvous::Args owns the device context pointer.
@@ -273,7 +271,7 @@ xla::RecvDeviceMemoryFunction GetRecvDeviceMemoryFunction(
           int64_t channel_id, se::Stream* stream, const xla::Shape& shape,
           se::DeviceMemoryBase* device_memory_base,
           const absl::flat_hash_map<std::string, std::string>& frontend_attrs)
-      -> absl::StatusOr<tsl::AsyncValueRef<se::Event>> {
+      -> absl::StatusOr<tsl::AsyncValueRef<std::unique_ptr<se::Event>>> {
     auto iter = frontend_attrs.find("_xla_host_transfer_rendezvous");
 
     // Generate the Rendezvous key.
@@ -293,12 +291,10 @@ xla::RecvDeviceMemoryFunction GetRecvDeviceMemoryFunction(
     RendezvousInterface::ParsedKey parsed_key;
     TF_RETURN_IF_ERROR(Rendezvous::ParseKey(rendezvous_key, &parsed_key));
 
-    tsl::AsyncValueRef<se::Event> done_event =
-        tsl::MakeConstructedAsyncValueRef<se::Event>(stream->parent());
-    if (!done_event->Init()) {
-      return errors::Internal(
-          "Failed to initialize done event (channel_id=%d)", channel_id);
-    }
+    TF_ASSIGN_OR_RETURN(auto event, stream->parent()->CreateEvent());
+    tsl::AsyncValueRef<std::unique_ptr<se::Event>> done_event =
+        tsl::MakeConstructedAsyncValueRef<std::unique_ptr<se::Event>>(
+            std::move(event));
 
     Rendezvous::Args args;
     // Rendezvous::Args owns the device context pointer.
diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc
index 8edb3e456c4c00..6f37d5617b6ce6 100644
--- a/tensorflow/compiler/jit/node_matchers_test.cc
+++ b/tensorflow/compiler/jit/node_matchers_test.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/node_matchers.h" +#include + #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" @@ -117,12 +119,26 @@ TEST(NodeMatchers, CheckControlDependence) { EXPECT_THAT(placeholder_d.node(), NodeWith(Name("placeholder_d"), CtrlDeps())); - EXPECT_EQ( - Explain(placeholder_c.node(), NodeWith(CtrlDeps())), - "ctrl_deps, which has 2 elements, does not match expected: is empty"); - EXPECT_EQ(Explain(placeholder_d.node(), NodeWith(CtrlDeps(NodeWith()))), - "ctrl_deps does not match expected: has 1 element and that element " - "is any node"); + // TODO(griffithjames): Exactly match these explanations. + // + // When the OSS build has been updated to include the new error messages, the + // Explain() expectations can be exact strings again. + { + const std::string explanation = + Explain(placeholder_c.node(), NodeWith(CtrlDeps())); + EXPECT_NE(explanation.find("ctrl_deps, which has 2 elements"), + std::string::npos); + EXPECT_NE(explanation.find("does not match expected: is empty"), + std::string::npos); + } + { + const std::string explanation = + Explain(placeholder_d.node(), NodeWith(CtrlDeps(NodeWith()))); + EXPECT_NE(explanation.find("ctrl_deps"), std::string::npos); + EXPECT_NE(explanation.find("does not match expected: has 1 element and " + "that element is any node"), + std::string::npos); + } } TEST(NodeMatchers, ConstValue) { diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index b5b0c16422ccab..471f54571d2b53 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -52,7 +52,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 821d294af90f66..faf3b65d407a7e 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor_reference.h" +#include "tsl/platform/statusor.h" namespace tensorflow { @@ -171,8 +172,8 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer())); if (UseMultipleStreams()) { - auto event = std::make_shared(stream_->parent()); - TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; + TF_ASSIGN_OR_RETURN(std::shared_ptr event, + stream_->parent()->CreateEvent()); TF_RETURN_IF_ERROR(host_to_device_stream_->RecordEvent(event.get())); xla_tensor->ResetDefinitionEvent(std::move(event), host_to_device_stream_.get()); diff --git a/tensorflow/compiler/jit/xla_host_recv_device_context.cc b/tensorflow/compiler/jit/xla_host_recv_device_context.cc index 54f22fe59fa0bf..ae3c149d5d1387 100644 --- a/tensorflow/compiler/jit/xla_host_recv_device_context.cc +++ b/tensorflow/compiler/jit/xla_host_recv_device_context.cc @@ -38,7 +38,7 @@ void XlaHostRecvDeviceContext::CopyDeviceTensorToCPU( done(status); return; } - status = stream_->RecordEvent(&done_event_.get()); + status = stream_->RecordEvent(done_event_.get().get()); if (!status.ok()) { done(status); return; diff --git a/tensorflow/compiler/jit/xla_host_recv_device_context.h b/tensorflow/compiler/jit/xla_host_recv_device_context.h index 8938fd9c9e0c17..028fd4efd68091 100644 --- a/tensorflow/compiler/jit/xla_host_recv_device_context.h +++ b/tensorflow/compiler/jit/xla_host_recv_device_context.h @@ -36,8 +36,8 @@ namespace tensorflow { // Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2})); // se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)}; // xla::Shape shape(xla::F32, {2, 2}, {}, {}) -// tsl::AsyncValueRef done_event = -// tsl::MakeConstructedAsyncValueRef(stream.parent()); +// tsl::AsyncValueRef> done_event = +// tsl::MakeConstructedAsyncValueRef>(stream.parent()); // done_event->Init(); // Tensor dest_cpu_tensor; // @@ -48,10 +48,10 @@ namespace tensorflow { class XlaHostRecvDeviceContext : public DeviceContext { public: - XlaHostRecvDeviceContext(se::Stream* stream, - const se::DeviceMemoryBase& device_memory_base, - const xla::Shape& shape, - tsl::AsyncValueRef& done_event) + XlaHostRecvDeviceContext( + se::Stream* stream, const se::DeviceMemoryBase& device_memory_base, + const xla::Shape& shape, + tsl::AsyncValueRef>& done_event) : stream_(stream), device_memory_base_(device_memory_base), shape_(shape), @@ -82,7 +82,7 @@ class XlaHostRecvDeviceContext : public DeviceContext { // not an issue here since only DeviceMemoryBase methods/members are used. 
   const se::DeviceMemoryBase device_memory_base_;
   const xla::Shape shape_;
-  tsl::AsyncValueRef<se::Event> done_event_;
+  tsl::AsyncValueRef<std::unique_ptr<se::Event>> done_event_;
 
   XlaHostRecvDeviceContext(const XlaHostRecvDeviceContext&) = delete;
   void operator=(const XlaHostRecvDeviceContext&) = delete;
diff --git a/tensorflow/compiler/jit/xla_host_send_device_context.cc b/tensorflow/compiler/jit/xla_host_send_device_context.cc
index 5d106c8dc3e073..3d1a9a9f5228c6 100644
--- a/tensorflow/compiler/jit/xla_host_send_device_context.cc
+++ b/tensorflow/compiler/jit/xla_host_send_device_context.cc
@@ -28,7 +28,7 @@ void XlaHostSendDeviceContext::CopyCPUTensorToDevice(
     done(status);
     return;
   }
-  status = stream_->RecordEvent(&done_event_.get());
+  status = stream_->RecordEvent(done_event_.get().get());
   if (!status.ok()) {
     done(status);
     return;
diff --git a/tensorflow/compiler/jit/xla_host_send_device_context.h b/tensorflow/compiler/jit/xla_host_send_device_context.h
index d7a254770c969e..f4e4e9a2535341 100644
--- a/tensorflow/compiler/jit/xla_host_send_device_context.h
+++ b/tensorflow/compiler/jit/xla_host_send_device_context.h
@@ -37,8 +37,8 @@ namespace tensorflow {
 //  Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2}));
 //  se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
 //  xla::Shape shape(xla::F32, {2, 2}, {}, {})
-//  tsl::AsyncValueRef<se::Event> done_event =
-//      tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
+//  tsl::AsyncValueRef<std::unique_ptr<se::Event>> done_event =
+//      tsl::MakeConstructedAsyncValueRef<std::unique_ptr<se::Event>>(stream.parent());
 //  done_event->Init();
 //
 //  XlaHostSendDeviceContext device_context(&stream, &gpu_dst,
@@ -48,10 +48,10 @@ namespace tensorflow {
 class XlaHostSendDeviceContext : public DeviceContext {
  public:
-  XlaHostSendDeviceContext(se::Stream* stream,
-                           se::DeviceMemoryBase* device_memory_base,
-                           const xla::Shape& shape,
-                           tsl::AsyncValueRef<se::Event>& done_event)
+  XlaHostSendDeviceContext(
+      se::Stream* stream, se::DeviceMemoryBase* device_memory_base,
+      const xla::Shape& shape,
+      tsl::AsyncValueRef<std::unique_ptr<se::Event>>& done_event)
       : stream_(stream),
         device_memory_base_(device_memory_base),
         shape_(shape),
@@ -79,7 +79,7 @@ class XlaHostSendDeviceContext : public DeviceContext {
   se::Stream* stream_;                        // Not owned.
   se::DeviceMemoryBase* device_memory_base_;  // Not owned.
   const xla::Shape shape_;
-  tsl::AsyncValueRef<se::Event> done_event_;
+  tsl::AsyncValueRef<std::unique_ptr<se::Event>> done_event_;
 
   XlaHostSendDeviceContext(const XlaHostSendDeviceContext&) = delete;
   void operator=(const XlaHostSendDeviceContext&) = delete;
diff --git a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc
index 16f42d1dbe1a0d..62da04c3e7510f 100644
--- a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc
+++ b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc
@@ -79,9 +79,10 @@ TEST_F(XlaHostSendRecvDeviceContextTest, CopyDeviceTensorToCPU) {
       stream->Memcpy(&gpu_dst, origin_cpu_tensor.data(), gpu_dst.size()));
   TF_ASSERT_OK(stream->BlockHostUntilDone());
 
-  tsl::AsyncValueRef<se::Event> done_event =
-      tsl::MakeConstructedAsyncValueRef<se::Event>(executor);
-  done_event->Init();
+  TF_ASSERT_OK_AND_ASSIGN(auto se_event, executor->CreateEvent());
+  tsl::AsyncValueRef<std::unique_ptr<se::Event>> done_event =
+      tsl::MakeConstructedAsyncValueRef<std::unique_ptr<se::Event>>(
+          std::move(se_event));
   XlaHostRecvDeviceContext* device_context =
       new XlaHostRecvDeviceContext(stream.get(), gpu_dst, shape, done_event);
   TF_ASSERT_OK(device_context->CopyDeviceTensorToCPUSync(
@@ -108,9 +109,10 @@ TEST_F(XlaHostSendRecvDeviceContextTest, CopyCPUTensorToDevice) {
   xla::Shape shape;
   TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape));
 
-  tsl::AsyncValueRef<se::Event> done_event =
-      tsl::MakeConstructedAsyncValueRef<se::Event>(executor);
-  done_event->Init();
+  TF_ASSERT_OK_AND_ASSIGN(auto se_event, executor->CreateEvent());
+  tsl::AsyncValueRef<std::unique_ptr<se::Event>> done_event =
+      tsl::MakeConstructedAsyncValueRef<std::unique_ptr<se::Event>>(
+          std::move(se_event));
   XlaHostSendDeviceContext* device_context =
       new XlaHostSendDeviceContext(stream.get(), &gpu_dst, shape, done_event);
   TF_ASSERT_OK(device_context->CopyCPUTensorToDeviceSync(
@@ -141,17 +143,19 @@ TEST_F(XlaHostSendRecvDeviceContextTest, RoundTrip) {
   xla::Shape shape;
   TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape));
 
-  tsl::AsyncValueRef<se::Event> send_done_event =
-      tsl::MakeConstructedAsyncValueRef<se::Event>(executor);
-  send_done_event->Init();
+  TF_ASSERT_OK_AND_ASSIGN(auto se_event, executor->CreateEvent());
+  tsl::AsyncValueRef<std::unique_ptr<se::Event>> send_done_event =
+      tsl::MakeConstructedAsyncValueRef<std::unique_ptr<se::Event>>(
+          std::move(se_event));
   XlaHostSendDeviceContext* send_device_context = new XlaHostSendDeviceContext(
       stream.get(), &gpu_dst, shape, send_done_event);
   TF_ASSERT_OK(send_device_context->CopyCPUTensorToDeviceSync(
       &origin_cpu_tensor, device_.get(), &device_tensor));
 
-  tsl::AsyncValueRef<se::Event> recv_done_event =
-      tsl::MakeConstructedAsyncValueRef<se::Event>(executor);
-  recv_done_event->Init();
+  TF_ASSERT_OK_AND_ASSIGN(auto recv_se_event, executor->CreateEvent());
+  tsl::AsyncValueRef<std::unique_ptr<se::Event>> recv_done_event =
+      tsl::MakeConstructedAsyncValueRef<std::unique_ptr<se::Event>>(
+          std::move(recv_se_event));
   XlaHostRecvDeviceContext* recv_device_context = new XlaHostRecvDeviceContext(
       stream.get(), gpu_dst, shape, recv_done_event);
   TF_ASSERT_OK(recv_device_context->CopyDeviceTensorToCPUSync(
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 9107e07b83bc21..cfeaa937024b32 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -390,10 +390,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
 
   std::shared_ptr<se::Event> definition_event;
   if (use_multiple_streams_ && stream) {
-    definition_event = std::make_shared<se::Event>(stream->parent());
-    if (!definition_event->Init()) {
-      return errors::Internal("Failed to initialize tensor definition event.");
-    }
+    TF_ASSIGN_OR_RETURN(definition_event, stream->parent()->CreateEvent());
     TF_RETURN_IF_ERROR(stream->RecordEvent(definition_event.get()));
   }
 
@@ -410,7 +407,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     if (output.on_host_shape().is_dynamic()) {
       const se::Platform* platform = nullptr;
       if (stream != nullptr) {
-        platform = stream->parent()->platform();
+        platform = stream->parent()->GetPlatform();
       } else {
         // Stream is not set for the host platform.
         TF_ASSIGN_OR_RETURN(platform,
@@ -670,7 +667,8 @@ Status PreparePjRtExecutableArguments(
       std::unique_ptr<xla::PjRtStreamExecutorBuffer> pjrt_buffer =
           std::make_unique<xla::PjRtStreamExecutorBuffer>(
              device_shape, std::move(device_buffer), pjrt_client,
-              pjrt_device);
+              pjrt_device,
+              pjrt_device->default_memory_space().value_or(nullptr));
      owned_args->push_back(std::move(pjrt_buffer));
      args->push_back(owned_args->back().get());
    }
@@ -866,7 +864,7 @@ Status RunPjRtExecutable(
                      pjrt_client->LookupAddressableDevice(pjrt_device_id));
 
   gpu::GpuServingDeviceSelectorResource* device_selector_resource = nullptr;
-  if (device_type == DEVICE_GPU) {
+  if (device_type == DEVICE_GPU && gpu::kUseGpuServingDeviceSelector) {
     auto rm = ctx->resource_manager();
     TF_RETURN_IF_ERROR(rm->LookupOrCreate<
                        gpu::GpuServingDeviceSelectorResource>(
diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc
index 1486340e95da3b..a6b066f7460168 100644
--- a/tensorflow/compiler/jit/xla_platform_info.cc
+++ b/tensorflow/compiler/jit/xla_platform_info.cc
@@ -377,7 +377,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
     auto device = static_cast(device_base);
     platform_id = device->tensorflow_accelerator_device_info()
                       ->stream->parent()
-                      ->platform()
+                      ->GetPlatform()
                       ->id();
   } else if (XlaDevice::GetMetadataFromDevice(device_base, &xla_device_metadata)
                  .ok()) {
diff --git a/tensorflow/compiler/jit/xla_tpu_device.cc b/tensorflow/compiler/jit/xla_tpu_device.cc
index dfedd586df69aa..403e6b17e6fc00 100644
--- a/tensorflow/compiler/jit/xla_tpu_device.cc
+++ b/tensorflow/compiler/jit/xla_tpu_device.cc
@@ -43,6 +43,7 @@ limitations under the License.
diff --git a/tensorflow/compiler/jit/xla_tpu_device.cc b/tensorflow/compiler/jit/xla_tpu_device.cc
index dfedd586df69aa..403e6b17e6fc00 100644
--- a/tensorflow/compiler/jit/xla_tpu_device.cc
+++ b/tensorflow/compiler/jit/xla_tpu_device.cc
@@ -43,6 +43,7 @@ limitations under the License.
 #include "tensorflow/core/tpu/tpu_defs.h"
 #include "tensorflow/core/tpu/tpu_node_device_util.h"
 #include "tensorflow/core/tpu/virtual_device.h"
+#include "tsl/platform/statusor.h"
 
 namespace tensorflow {
 namespace {
@@ -271,9 +272,8 @@ void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context,
         dst_xla_context->host_to_device_stream()));
   }
 
-  auto definition_event =
-      std::make_shared<se::Event>(dst_xla_context->stream()->parent());
-  TF_RET_CHECK(definition_event->Init()) << "Event failed to initialize!";
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<se::Event> definition_event,
+                      dst_xla_context->stream()->parent()->CreateEvent());
   TF_RETURN_IF_ERROR(
       dst_device_to_device_stream->RecordEvent(definition_event.get()));
   xla_output->ResetDefinitionEvent(std::move(definition_event),
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index d0286e5acff9ce..46d5e7e9fb9005 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -37,6 +37,7 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
     ],
 )
 
@@ -51,7 +52,6 @@ cc_library(
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/lite:tf_tfl_passes",  # buildcleaner:keep
         "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes",
-        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
         "//tensorflow/compiler/mlir/tensorflow:mlprogram_util",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_test_passes",
@@ -75,7 +75,6 @@ cc_library(
         "@local_xla//xla/mlir/framework/ir:xla_framework",
         "@local_xla//xla/mlir/framework/transforms:passes",
         "@local_xla//xla/mlir_hlo:all_passes",
-        "@local_xla//xla/service/cpu:hlo_xla_runtime_pipeline",
     ],
 )
 
@@ -190,7 +189,6 @@ cc_library(
     deps = [
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
         "//tensorflow/compiler/mlir/tensorflow:mlprogram_util",
         "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
         "@llvm-project//mlir:AllExtensions",
@@ -204,7 +202,6 @@ cc_library(
         "@llvm-project//mlir:Transforms",
         "@local_xla//xla/mlir/framework/ir:xla_framework",
         "@local_xla//xla/mlir_hlo:hlo_dialect_registration",
-        "@local_xla//xla/service/cpu:hlo_xla_runtime_pipeline",
         "@stablehlo//:register",
     ],
 )
@@ -229,7 +226,6 @@ tf_cc_binary(
         "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
         "//tensorflow/compiler/mlir/tensorflow:translate_lib",
         "//tensorflow/compiler/mlir/tensorflow:translate_registration",
-        "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
         "//tensorflow/core:lib",
         "//tensorflow/core:tensorflow",
         "@com_google_absl//absl/strings",
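`TF_ASSIGN_OR_RETURN`, used in the xla_launch_util.cc and xla_tpu_device.cc hunks above (and its test-side cousin `TF_ASSERT_OK_AND_ASSIGN`), unwraps a `StatusOr` or returns early on error. Conceptually it expands to the fragment below; this is a simplified hand expansion, not the real macro, which lives in `tsl/platform/statusor.h` and generates a hygienic temporary name. The fragment assumes the surrounding Status-returning function from the hunk:

```cpp
// Hand expansion of:
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<se::Event> definition_event,
//                       dst_xla_context->stream()->parent()->CreateEvent());
auto statusor = dst_xla_context->stream()->parent()->CreateEvent();
if (!statusor.ok()) {
  return statusor.status();  // early-return the error to the caller
}
// The StatusOr holds a std::unique_ptr<se::Event>, which converts
// implicitly to the declared std::shared_ptr<se::Event>.
std::shared_ptr<se::Event> definition_event = std::move(statusor).value();
```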
diff --git a/tensorflow/compiler/mlir/init_mlir.cc b/tensorflow/compiler/mlir/init_mlir.cc
index 938cd52359b9d6..ce7cefabcdcf73 100644
--- a/tensorflow/compiler/mlir/init_mlir.cc
+++ b/tensorflow/compiler/mlir/init_mlir.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/init_mlir.h"
 
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/PrettyStackTrace.h"
 #include "tensorflow/core/platform/init_main.h"
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index c3826f1bfb935c..7e49b1d028ce69 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -19,6 +19,7 @@ package_group(
         "//third_party/odml/infra/...",
         "//tensorflow/compiler/mlir/...",
         "//tensorflow/lite/python/...",
+        "//waymo/accelerator/alpine/tools/...",
         "//waymo/ml/compiler/mlir/...",
         # Allow visibility from the mlir language server.
         "//learning/brain/mlir/mlir_lsp_server/...",
@@ -310,6 +311,7 @@ cc_library(
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
     ],
 )
 
@@ -360,6 +362,7 @@ cc_library(
         ":tensorflow_lite_ops_inc_gen",
         ":tensorflow_lite_passes_inc_gen",
         "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
+        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/compiler/mlir/quantization/common/quantization_lib",
         "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config",
         "//tensorflow/compiler/mlir/tensorflow",
@@ -369,7 +372,6 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/core:framework",
-        "//tensorflow/lite/schema:schema_fbs",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
@@ -702,6 +704,8 @@ cc_library(
         ":variables_utils",
         "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
         "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
+        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
+        "//tensorflow/compiler/mlir/lite/stablehlo:optimize_layout",
         "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo",
         "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo",
         "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps",
@@ -725,7 +729,6 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:tensor_list",
-        "//tensorflow/lite/schema:schema_fbs",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
@@ -860,6 +863,7 @@ cc_library(
     deps = [
         "convert_type",
         ":op_quant_spec_getters_inc",
+        ":stateful_ops_utils",
         ":tensorflow_lite",
         ":tensorflow_lite_passes_inc_gen",
         ":tensorflow_lite_post_quantize_inc_gen",
@@ -867,6 +871,7 @@ cc_library(
         ":validators",
         "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
         "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
+        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types",
         "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps",
         "//tensorflow/compiler/mlir/quantization/common/quantization_lib",
@@ -875,7 +880,6 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:logging",
-        "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/tools/optimize:operator_property",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
@@ -910,6 +914,7 @@ cc_library(
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
     ],
 )
 
@@ -1020,6
+1025,8 @@ cc_library( ":convert_type", ":converter_inc", ":tensorflow_lite", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core/platform:errors", @@ -1027,8 +1034,6 @@ cc_library( "//tensorflow/core/platform:statusor", "//tensorflow/lite/core/c:private_common", "//tensorflow/lite/kernels/internal:kernel_utils", - "//tensorflow/lite/schema:schema_fbs", - "//tensorflow/lite/schema:schema_fbs_with_mutable", "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -1036,6 +1041,7 @@ cc_library( "@llvm-project//llvm:Analysis", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_tsl//tsl/platform:status", "@local_xla//xla:statusor", @@ -1049,8 +1055,8 @@ tf_native_cc_binary( name = "flatbuffer_to_string", srcs = ["flatbuffer_to_string.cc"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_reflection", "//tensorflow/lite/core:model_builder", - "//tensorflow/lite/schema:schema_fbs_with_reflection", "@flatbuffers", ], ) @@ -1059,26 +1065,11 @@ tf_native_cc_binary( name = "json_to_flatbuffer", srcs = ["json_to_flatbuffer.cc"], deps = [ - "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "@flatbuffers", ], ) -cc_library( - name = "emit_error_reporter", - srcs = [ - "emit_error_reporter.cc", - ], - hdrs = [ - "emit_error_reporter.h", - ], - deps = [ - "//tensorflow/lite/core/api", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - ], -) - cc_library( name = "flatbuffer_export", srcs = [ @@ -1097,6 +1088,7 @@ cc_library( "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", @@ -1114,7 +1106,6 @@ cc_library( "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", "//tensorflow/lite/schema:schema_conversion_utils", - "//tensorflow/lite/schema:schema_fbs_with_mutable", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/versioning", "//tensorflow/lite/tools/versioning:gpu_compatibility", @@ -1124,6 +1115,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", "@flatbuffers", "@llvm-project//llvm:Support", @@ -1156,6 +1148,7 @@ cc_library( ":size_utils", ":tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", @@ -1170,7 +1163,6 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/lite:framework", "//tensorflow/lite/experimental/remat:metadata_util", - 
"//tensorflow/lite/schema:schema_fbs_with_mutable", "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1203,11 +1195,12 @@ cc_library( ], deps = [ ":tensorflow_lite", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", - "//tensorflow/lite/schema:schema_fbs", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_xla//xla:statusor", ], ) @@ -1224,8 +1217,8 @@ cc_library( ":flatbuffer_export", ":flatbuffer_import", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/toco:toco_flags_proto_cc", "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", @@ -1314,8 +1307,8 @@ tf_cc_binary( ":tf_to_tfl_flatbuffer", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", - "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", @@ -1325,7 +1318,6 @@ tf_cc_binary( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "//tensorflow/lite:framework", - "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1391,6 +1383,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/lite/stablehlo:uniform_quantized_stablehlo_to_tfl_pass", + "//tensorflow/compiler/mlir/lite/stablehlo/odml_converter:outline_composites", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", @@ -1423,6 +1416,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/metrics:error_collector", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization/stablehlo:quantization", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:op_stat_pass", @@ -1449,7 +1443,6 @@ cc_library( "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", - "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/lite/tools/optimize:reduced_precision_support", @@ -1498,11 +1491,11 @@ cc_library( deps = [ ":convert_type", ":low_bit_utils", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:protos_all_cc", 
"//tensorflow/lite:string_util", - "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/base", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 69ec0bbbcee3dc..1149d7841b38fd 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -101,7 +101,7 @@ struct PassConfig { bool enable_stablehlo_quantizer = false; // Enables the attempt to directly lower composites into tflite ops. - bool enable_composite_direct_lowering = false; + bool enable_composite_direct_lowering = true; }; inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, diff --git a/tensorflow/compiler/mlir/lite/emit_error_reporter.h b/tensorflow/compiler/mlir/lite/emit_error_reporter.h deleted file mode 100644 index 9e9a5925600fc2..00000000000000 --- a/tensorflow/compiler/mlir/lite/emit_error_reporter.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EMIT_ERROR_REPORTER_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EMIT_ERROR_REPORTER_H_ - -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "tensorflow/lite/core/api/error_reporter.h" - -namespace tflite { - -// Error reporter that reports errors via the module's emitError. 
-class EmitErrorReporter : public ErrorReporter {
- public:
-  explicit EmitErrorReporter(mlir::ModuleOp module) : module_(module) {}
-  int Report(const char* format, va_list args) override;
-
- private:
-  mlir::ModuleOp module_;
-};
-
-}  // namespace tflite
-
-#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EMIT_ERROR_REPORTER_H_
diff --git a/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc
index 5f77797b9aa8a7..59cc28f9fa0608 100644
--- a/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc
+++ b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc
@@ -44,7 +44,7 @@ namespace common {
 
 bool IsConstantOrNone(Operation* op) {
   return (op->getNumResults() == 1 &&
-          op->getResult(0).getType().isa<NoneType>()) ||
+          mlir::isa<NoneType>(op->getResult(0).getType())) ||
          matchPattern(op, m_Constant()) || isa<QConstOp>(op);
 }
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD
index 248a55c7fe17e1..b7c6eb7055221e 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD
@@ -42,6 +42,7 @@ cc_library(
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
     ],
 )
 
@@ -88,6 +89,7 @@ cc_library(
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:SideEffectInterfaces",
+        "@llvm-project//mlir:Support",
     ],
 )
 
@@ -282,7 +284,6 @@ cc_library(
     deps = [
         ":target_aware_conversion",
         "//tensorflow/compiler/mlir:tf_mlir_opt_main",
-        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
     ],
     alwayslink = 1,
 )
@@ -324,7 +325,6 @@ cc_library(
         "//tensorflow/compiler/mlir/lite/experimental/tac/utils",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:error_util",
-        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/status",
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h b/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h
index 40f4902e655bcd..88382e8cf6f27b 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/utils/utils.h"
@@ -52,7 +53,7 @@ bool NotTFLQuantDequantizeOp(Operation* op);
 
 // Returns true if it is a shaped type of f32 elements.
 inline bool IsF32ShapedType(Type t) {
-  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
+  if (auto shaped_type = mlir::dyn_cast_or_null<ShapedType>(t)) {
     return shaped_type.getElementType().isF32();
   }
   return false;
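The cast churn in outline_operations.cc and utils.h above reflects MLIR's deprecation of the member-function casts (`x.isa<T>()`, `x.dyn_cast_or_null<T>()`) in favor of the free functions re-exported through `mlir/Support/LLVM.h`, which is also why these hunks add `@llvm-project//mlir:Support` deps. A before/after sketch using only standard MLIR API, mirroring the IsF32ShapedType change above:

```cpp
#include "mlir/IR/BuiltinTypes.h"  // mlir::ShapedType
#include "mlir/IR/Types.h"         // mlir::Type
#include "mlir/Support/LLVM.h"     // brings llvm::isa/dyn_cast_or_null into mlir::

bool IsF32Shaped(mlir::Type t) {
  // Before (deprecated): t.dyn_cast_or_null<mlir::ShapedType>()
  // After: the free-function form, which reads the same for types,
  // attributes, and values alike.
  if (auto shaped = mlir::dyn_cast_or_null<mlir::ShapedType>(t))
    return shaped.getElementType().isF32();
  return false;
}
```

Separately, the long runs of FileCheck updates in the .mlir tests that follow are not about casts: they track MLIR's move of inherent op attributes into "properties", which the generic printer now emits as `<{...}>` ahead of the discardable attribute dictionary. That is why `"tfl.concatenation"(...) {axis = 3 : i32, ..., tac.device = "GPU"}` becomes `"tfl.concatenation"(...) <{axis = 3 : i32, ...}> {tac.device = "GPU"}` throughout.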
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc
index 9b2458571f0c34..11a1b31e5102de 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc
@@ -29,6 +29,7 @@
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Region.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
 #include "tensorflow/compiler/mlir/lite/experimental/tac/runtime_metadata_generated.h"
@@ -82,8 +83,7 @@ std::optional<std::vector<float>> GetPerDeviceCosts(
   for (const auto& kv : hardware_map) {
     auto cost_attr = device_costs_attr.getNamed(kv.first);
     if (!cost_attr.has_value()) return std::nullopt;
-    float cost = cost_attr->getValue()
-                     .dyn_cast_or_null<mlir::FloatAttr>()
-                     .getValueAsDouble();
+    float cost = mlir::dyn_cast_or_null<mlir::FloatAttr>(cost_attr->getValue())
                     .getValueAsDouble();
     device_costs[kv.second] = cost;
   }
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-gpu.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-gpu.mlir
index c9a3999dad0a68..5ee1a71e344933 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-gpu.mlir
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-gpu.mlir
@@ -11,7 +11,7 @@ func.func @pack(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> {
 // CHECK-DAG: %[[VAL_4:.*]] = "tfl.pseudo_const"{{.*}}dense<[2, 1]> : tensor<2xi32>
 // CHECK: %[[VAL_5:.*]] = "tfl.reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<1xf32>, tensor<4xi32>) -> tensor<1x1x1x1xf32>
 // CHECK: %[[VAL_6:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_2]]) : (tensor<1xf32>, tensor<4xi32>) -> tensor<1x1x1x1xf32>
-// CHECK: %[[VAL_7:.*]] = "tfl.concatenation"(%[[VAL_5]], %[[VAL_6]]) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x2xf32>
+// CHECK: %[[VAL_7:.*]] = "tfl.concatenation"(%[[VAL_5]], %[[VAL_6]]) <{axis = 3 : i32, fused_activation_function = "NONE"}> : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x2xf32>
 // CHECK: %[[VAL_8:.*]] = "tfl.reshape"(%[[VAL_7]], %[[VAL_3]]) : (tensor<1x1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32>
 // CHECK: %[[VAL_9:.*]] = "tfl.reshape"(%[[VAL_8]], %[[VAL_4]]) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32>
 // CHECK: return %[[VAL_9]] : tensor<2x1xf32>
@@ -124,8 +124,8 @@ func.func @sub(%arg0: tensor<1x384x384x3xf32>, %arg1: tensor<3xf32>) -> tensor<1x384x384x3xf32> {
 // CHECK: func @sub(%[[VAL_0:.*]]: tensor<1x384x384x3xf32>, %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<1x384x384x3xf32> {
 // CHECK: %[[VAL_2:.*]] = arith.constant dense<-1.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_3:.*]] = tfl.mul(%[[VAL_1]], %[[VAL_2]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<f32>) -> tensor<3xf32>
-// CHECK: %[[VAL_4:.*]] = tfl.add(%[[VAL_0]], %[[VAL_3]]) {fused_activation_function = "NONE"} : (tensor<1x384x384x3xf32>, tensor<3xf32>) -> tensor<1x384x384x3xf32>
+// CHECK: %[[VAL_3:.*]] = tfl.mul(%[[VAL_1]], %[[VAL_2]]) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<f32>) -> tensor<3xf32>
+// CHECK: %[[VAL_4:.*]] = tfl.add(%[[VAL_0]], %[[VAL_3]]) <{fused_activation_function = "NONE"}> : (tensor<1x384x384x3xf32>, tensor<3xf32>) -> tensor<1x384x384x3xf32>
 // CHECK: return %[[VAL_4]] : tensor<1x384x384x3xf32>
 // CHECK: }
@@ -139,7 +139,7 @@ func.func @ensureBiasForConv2d(%arg0: tensor<128x32x32x3xf32>, %arg1: tensor<32x1x1x3xf32>) -> tensor<128x32x32x32xf32> {
 // CHECK: func @ensureBiasForConv2d(%[[VAL_0:.*]]: tensor<128x32x32x3xf32>, %[[VAL_1:.*]]: tensor<32x1x1x3xf32>) -> tensor<128x32x32x32xf32> {
 // CHECK: %[[VAL_2:.*]] = "tfl.pseudo_const"{{.*}}dense<0.000000e+00> : tensor<32xf32>
-// CHECK: %[[VAL_3:.*]] = "tfl.conv_2d"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<128x32x32x3xf32>, tensor<32x1x1x3xf32>, tensor<32xf32>) -> tensor<128x32x32x32xf32>
+// CHECK: %[[VAL_3:.*]] = "tfl.conv_2d"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<128x32x32x3xf32>, tensor<32x1x1x3xf32>, tensor<32xf32>) -> tensor<128x32x32x32xf32>
 // CHECK: return %[[VAL_3]] : tensor<128x32x32x32xf32>
 // CHECK: }
@@ -156,7 +156,7 @@ func.func @padSliceTo4D(%arg0: tensor<4x384x32xf32>) -> tensor<1x384x32xf32> {
 // CHECK-DAG: %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK-DAG: %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 1, 384, 32]> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK-DAG: %[[VAL_3:.*]] = "tfl.pseudo_const"{{.*}}dense<[1, 4, 384, 32]> : tensor<4xi32>
-// CHECK-DAG: %[[VAL_4:.*]] = "tfl.pseudo_const"() {value = dense<[1, 384, 32]> : tensor<3xi32>
+// CHECK-DAG: %[[VAL_4:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 384, 32]> : tensor<3xi32>
 // CHECK: %[[VAL_5:.*]] = "tfl.reshape"(%[[VAL_0]], %[[VAL_3]]) : (tensor<4x384x32xf32>, tensor<4xi32>) -> tensor<1x4x384x32xf32>
 // CHECK: %[[VAL_6:.*]] = "tfl.slice"(%[[VAL_5]], %[[VAL_1]], %[[VAL_2]]) : (tensor<1x4x384x32xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x1x384x32xf32>
 // CHECK: %[[VAL_7:.*]] = "tfl.reshape"(%[[VAL_6]], %[[VAL_4]]) : (tensor<1x1x384x32xf32>, tensor<3xi32>) -> tensor<1x384x32xf32>
@@ -189,7 +189,7 @@ func.func @fullyConnectedToConv(%arg0: tensor<384x384xf32>, %arg1: tensor<512x384xf32>
 // CHECK-DAG: %[[VAL_5:.*]] = "tfl.pseudo_const"{{.*}}dense<[384, 512]> : tensor<2xi32>
 // CHECK: %[[VAL_6:.*]] = "tfl.reshape"(%[[VAL_0]], %[[VAL_3]]) : (tensor<384x384xf32>, tensor<4xi32>) -> tensor<1x1x384x384xf32>
 // CHECK: %[[VAL_7:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_4]]) : (tensor<512x384xf32>, tensor<4xi32>) -> tensor<512x1x1x384xf32>
-// CHECK: %[[VAL_8:.*]] = "tfl.conv_2d"(%[[VAL_6]], %[[VAL_7]], %[[VAL_2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x1x384x384xf32>, tensor<512x1x1x384xf32>, tensor<512xf32>) -> tensor<1x1x384x512xf32>
+// CHECK: %[[VAL_8:.*]] = "tfl.conv_2d"(%[[VAL_6]], %[[VAL_7]], %[[VAL_2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x1x384x384xf32>, tensor<512x1x1x384xf32>, tensor<512xf32>) -> tensor<1x1x384x512xf32>
 // CHECK: %[[VAL_9:.*]] = "tfl.reshape"(%[[VAL_8]], %[[VAL_5]]) : (tensor<1x1x384x512xf32>, tensor<2xi32>) -> tensor<384x512xf32>
 // CHECK: return %[[VAL_9]] : tensor<384x512xf32>
 // CHECK: }
@@ -208,7 +208,7 @@ func.func @padConcatTo4D(%arg0: tensor<384x384xf32>, %arg1: tensor<384x384xf32>,
 // CHECK: %[[VAL_7:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_4]]) : (tensor<384x384xf32>, tensor<4xi32>) -> tensor<1x1x384x384xf32>
 // CHECK: %[[VAL_8:.*]] = "tfl.reshape"(%[[VAL_2]], %[[VAL_4]]) : (tensor<384x384xf32>, tensor<4xi32>) -> tensor<1x1x384x384xf32>
 // CHECK: %[[VAL_9:.*]] = "tfl.reshape"(%[[VAL_3]], %[[VAL_4]]) : (tensor<384x384xf32>, tensor<4xi32>) -> tensor<1x1x384x384xf32>
-// CHECK: %[[VAL_10:.*]] = "tfl.concatenation"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]]) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x1x384x384xf32>, tensor<1x1x384x384xf32>, tensor<1x1x384x384xf32>, tensor<1x1x384x384xf32>) -> tensor<1x1x1536x384xf32>
+// CHECK: %[[VAL_10:.*]] = "tfl.concatenation"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]]) <{axis = 2 : i32, fused_activation_function = "NONE"}> : (tensor<1x1x384x384xf32>, tensor<1x1x384x384xf32>, tensor<1x1x384x384xf32>, tensor<1x1x384x384xf32>) -> tensor<1x1x1536x384xf32>
 // CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_5]]) : (tensor<1x1x1536x384xf32>, tensor<2xi32>) -> tensor<1536x384xf32>
 // CHECK: return %[[VAL_11]] : tensor<1536x384xf32>
 // CHECK: }
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-nnapi.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-nnapi.mlir
index 8918291711354e..41e57494486187 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-nnapi.mlir
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-nnapi.mlir
@@ -7,7 +7,7 @@ func.func @mean_4d_keepdim(%arg0: tensor<1x48x48x512xf32>) -> tensor<1x1x1x512xf32> {
 }
 // CHECK: func @mean_4d_keepdim([[VAL_0:%.*]]: tensor<1x48x48x512xf32>) -> tensor<1x1x1x512xf32> {
-// CHECK: [[VAL_1:%.*]] = "tfl.average_pool_2d"([[VAL_0]]) {filter_height = 48 : i32, filter_width = 48 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x48x48x512xf32>) -> tensor<1x1x1x512xf32>
+// CHECK: [[VAL_1:%.*]] = "tfl.average_pool_2d"([[VAL_0]]) <{filter_height = 48 : i32, filter_width = 48 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x48x48x512xf32>) -> tensor<1x1x1x512xf32>
 // CHECK: return [[VAL_1]] : tensor<1x1x1x512xf32>
 // CHECK: }
@@ -21,7 +21,7 @@ func.func @mean_4d_no_keepdim(%arg0: tensor<1x48x48x512xf32>) -> tensor<1x512xf32> {
 }
 // CHECK: func @mean_4d_no_keepdim([[VAL_0:%.*]]: tensor<1x48x48x512xf32>) -> tensor<1x512xf32> {
 // CHECK: [[VAL_1:%.*]] = "tfl.pseudo_const"(){{.*}}dense<[1, 512]> : tensor<2xi32>
-// CHECK: [[VAL_2:%.*]] = "tfl.average_pool_2d"([[VAL_0]]) {filter_height = 48 : i32, filter_width = 48 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x48x48x512xf32>) -> tensor<1x1x1x512xf32>
+// CHECK: [[VAL_2:%.*]] = "tfl.average_pool_2d"([[VAL_0]]) <{filter_height = 48 : i32, filter_width = 48 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x48x48x512xf32>) -> tensor<1x1x1x512xf32>
 // CHECK: [[VAL_3:%.*]] = "tfl.reshape"([[VAL_2]], [[VAL_1]]) : (tensor<1x1x1x512xf32>, tensor<2xi32>) -> tensor<1x512xf32>
 // CHECK: return [[VAL_3]] :
tensor<1x512xf32> // CHECK: } @@ -36,7 +36,7 @@ func.func @mean_quant_same_scale(%arg0: tensor>) -> tensor> { // CHECK: %[[VAL_1:.*]] = "tfl.pseudo_const"(){{.*}}dense<[-1, 2048]> : tensor<2xi32> -// CHECK: %[[VAL_2:.*]] = "tfl.average_pool_2d"(%[[VAL_0]]) {filter_height = 7 : i32, filter_width = 7 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor>) -> tensor> +// CHECK: %[[VAL_2:.*]] = "tfl.average_pool_2d"(%[[VAL_0]]) <{filter_height = 7 : i32, filter_width = 7 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor>) -> tensor> // CHECK: %[[VAL_3:.*]] = "tfl.reshape"(%[[VAL_2]], %[[VAL_1]]) : (tensor>, tensor<2xi32>) -> tensor> // CHECK: return %[[VAL_3]] : tensor> // CHECK: } @@ -51,8 +51,8 @@ func.func @mean_quant_different_scales(%arg0: tensor>) -> tensor> { // CHECK: %[[VAL_1:.*]] = "tfl.pseudo_const"(){{.*}}dense<[-1, 2048]> : tensor<2xi32> -// CHECK: %[[VAL_2:.*]] = "tfl.average_pool_2d"(%[[VAL_0]]) {filter_height = 7 : i32, filter_width = 7 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor>) -> tensor> +// CHECK: %[[VAL_2:.*]] = "tfl.average_pool_2d"(%[[VAL_0]]) <{filter_height = 7 : i32, filter_width = 7 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor>) -> tensor> // CHECK: %[[VAL_3:.*]] = "tfl.reshape"(%[[VAL_2]], %[[VAL_1]]) : (tensor>, tensor<2xi32>) -> tensor> -// CHECK: %[[VAL_4:.*]] = "tfl.quantize"(%[[VAL_3]]) {qtype = tensor>} : (tensor>) -> tensor> +// CHECK: %[[VAL_4:.*]] = "tfl.quantize"(%[[VAL_3]]) <{qtype = tensor>}> : (tensor>) -> tensor> // CHECK: return %[[VAL_4]] : tensor> // CHECK: } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/device-transform-nnapi.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/device-transform-nnapi.mlir index a69fc368ebcda4..55adc81a2e5713 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/device-transform-nnapi.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/device-transform-nnapi.mlir @@ -14,7 +14,7 @@ module { %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32> func.return %0 : tensor<2x1xf32> // CHECK: %[[VAL_0:.*]] = arith.constant dense<[2, 1]> : tensor<2xi32> - // CHECK: %[[CONCAT:.*]] = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2xf32> + // CHECK: %[[CONCAT:.*]] = "tfl.concatenation"(%arg0, %arg1) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<2xf32> // CHECK: %[[VAL_1:.*]] = "tfl.reshape"(%[[CONCAT]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32> // CHECK: return %[[VAL_1]] } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/simple-graph.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/simple-graph.mlir index b92d6c7a6f0103..a8c5a5f2ff1ee2 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/simple-graph.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/simple-graph.mlir @@ -12,7 +12,7 @@ func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32> // CHECK: %[[CST:.*]] = arith.constant dense<1> : tensor<4xi32> // CHECK: [[VAL_0:%.*]] = "tfl.reshape"(%1, %[[CST]]) {tac.device = "GPU", 
tac.inference_type = "FLOAT"} : (tensor<1xf32>, tensor<4xi32>) -> tensor<1x1x1x1xf32> // CHECK: [[VAL_1:%.*]] = "tfl.reshape"(%2, %[[CST]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1xf32>, tensor<4xi32>) -> tensor<1x1x1x1xf32> -// CHECK: [[VAL_2:%.*]] = "tfl.concatenation"([[VAL_0]], [[VAL_1]]) {axis = 3 : i32, fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x2xf32> +// CHECK: [[VAL_2:%.*]] = "tfl.concatenation"([[VAL_0]], [[VAL_1]]) <{axis = 3 : i32, fused_activation_function = "NONE"}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x2xf32> // CHECK: [[VAL_3:%.*]] = "tfl.reshape"([[VAL_2]], %{{.*}}) : (tensor<1x1x1x2xf32>, tensor<2xi32>) -> tensor<2x1xf32> } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/fold-constants-to-subgraph.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/fold-constants-to-subgraph.mlir index e8a30755a8c768..12a3b14e5f894f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/fold-constants-to-subgraph.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/fold-constants-to-subgraph.mlir @@ -20,8 +20,8 @@ func.func @simple_test(%arg0: tensor<4x384x32xf32>, %arg1: tensor<3xi32>, %arg2: } // PARTIAL: func @simple_test(%[[VAL_0:.*]]: tensor<4x384x32xf32>, %[[VAL_1:.*]]: tensor<3xi32>, %[[VAL_2:.*]]: tensor<3xi32>) -> tensor<1x384x32xf32> attributes {tac.interface_name = "func1"} { -// PARTIAL: %[[VAL_3:.*]] = "tfl.pseudo_const"() {value = dense<[1, 384, 32]> : tensor<3xi32>} : () -> tensor<3xi32> -// PARTIAL: %[[VAL_4:.*]] = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> +// PARTIAL: %[[VAL_3:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 384, 32]> : tensor<3xi32>}> : () -> tensor<3xi32> +// PARTIAL: %[[VAL_4:.*]] = "tfl.pseudo_const"() <{value = dense<0> : tensor<3xi32>}> : () -> tensor<3xi32> // PARTIAL: %[[VAL_5:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_4]], %[[VAL_3]]) : (tensor<4x384x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x384x32xf32> // PARTIAL: return %[[VAL_5]] : tensor<1x384x32xf32> // PARTIAL: } @@ -52,15 +52,15 @@ func.func @arg_reuse_test_2(%arg0: tensor<4x384x32xf32>, %arg1: tensor<3xi32>, % } // PARTIAL: func @arg_reuse_test_1(%[[VAL_0:.*]]: tensor<4x384x32xf32>, %[[VAL_1:.*]]: tensor<3xi32>, %[[VAL_2:.*]]: tensor<3xi32>) -> tensor<1x384x32xf32> attributes {tac.interface_name = "func1"} { -// PARTIAL: %[[VAL_3:.*]] = "tfl.pseudo_const"() {value = dense<[1, 384, 32]> : tensor<3xi32>} : () -> tensor<3xi32> -// PARTIAL: %[[VAL_4:.*]] = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> +// PARTIAL: %[[VAL_3:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 384, 32]> : tensor<3xi32>}> : () -> tensor<3xi32> +// PARTIAL: %[[VAL_4:.*]] = "tfl.pseudo_const"() <{value = dense<0> : tensor<3xi32>}> : () -> tensor<3xi32> // PARTIAL: %[[VAL_5:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_4]], %[[VAL_3]]) : (tensor<4x384x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x384x32xf32> // PARTIAL: return %[[VAL_5]] : tensor<1x384x32xf32> // PARTIAL: } // PARTIAL: func @arg_reuse_test_2(%[[VAL_6:.*]]: tensor<4x384x32xf32>, %[[VAL_7:.*]]: tensor<3xi32>, %[[VAL_8:.*]]: tensor<3xi32>) -> tensor<1x384x32xf32> attributes {tac.interface_name = "func2"} { -// PARTIAL: %[[VAL_9:.*]] = "tfl.pseudo_const"() {value = dense<[1, 384, 32]> : tensor<3xi32>} : () -> tensor<3xi32> -// PARTIAL: 
%[[VAL_10:.*]] = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> +// PARTIAL: %[[VAL_9:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 384, 32]> : tensor<3xi32>}> : () -> tensor<3xi32> +// PARTIAL: %[[VAL_10:.*]] = "tfl.pseudo_const"() <{value = dense<0> : tensor<3xi32>}> : () -> tensor<3xi32> // PARTIAL: %[[VAL_11:.*]] = "tfl.slice"(%[[VAL_6]], %[[VAL_10]], %[[VAL_9]]) : (tensor<4x384x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x384x32xf32> // PARTIAL: return %[[VAL_11]] : tensor<1x384x32xf32> // PARTIAL: } @@ -84,8 +84,8 @@ func.func @quantization_test(%arg0: tensor<384x512x!quant.uniform>, } // PARTIAL: func @quantization_test(%[[VAL_0:.*]]: tensor<384x512x!quant.uniform>, %[[VAL_1:.*]]: tensor<128x512x!quant.uniform:f32, 1.000000e-02>>, %[[VAL_2:.*]]: tensor<128x!quant.uniform>) -> tensor<384x128x!quant.uniform> { -// PARTIAL: %[[VAL_3:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<128x!quant.uniform>, value = dense<0> : tensor<128xi32>} : () -> tensor<128x!quant.uniform> -// PARTIAL: %[[VAL_4:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[VAL_1]], %[[VAL_3]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform>, tensor<128x512x!quant.uniform:f32, 1.000000e-02>>, tensor<128x!quant.uniform>) -> tensor<384x128x!quant.uniform> +// PARTIAL: %[[VAL_3:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<128x!quant.uniform>, value = dense<0> : tensor<128xi32>}> : () -> tensor<128x!quant.uniform> +// PARTIAL: %[[VAL_4:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[VAL_1]], %[[VAL_3]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<384x512x!quant.uniform>, tensor<128x512x!quant.uniform:f32, 1.000000e-02>>, tensor<128x!quant.uniform>) -> tensor<384x128x!quant.uniform> // PARTIAL: return %[[VAL_4]] : tensor<384x128x!quant.uniform> // PARTIAL: } @@ -108,9 +108,9 @@ func.func @fold_all_test(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3x } // ALL: func @fold_all_test(%[[VAL_0:.*]]: tensor<256x32x32x3xf32>, %[[VAL_1:.*]]: tensor<16x3x3x3xf32>, %[[VAL_2:.*]]: tensor<16xf32>) -> tensor<256x30x30x16xf32> { -// ALL: %[[VAL_3:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16xf32>} : () -> tensor<16xf32> -// ALL: %[[VAL_4:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x3x3x3xf32>} : () -> tensor<16x3x3x3xf32> -// ALL: %[[VAL_5:.*]] = "tfl.conv_2d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_3]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32, tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> +// ALL: %[[VAL_3:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<16xf32>}> : () -> tensor<16xf32> +// ALL: %[[VAL_4:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<16x3x3x3xf32>}> : () -> tensor<16x3x3x3xf32> +// ALL: %[[VAL_5:.*]] = "tfl.conv_2d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_3]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> // ALL: return %[[VAL_5]] : tensor<256x30x30x16xf32> // ALL: } } diff --git 
a/tensorflow/compiler/mlir/lite/experimental/tac/tests/get-alternative-subgraph.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/get-alternative-subgraph.mlir index 0e5101a5352b4c..80eecdeea1e44f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/get-alternative-subgraph.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/get-alternative-subgraph.mlir @@ -32,7 +32,7 @@ module { // CHECK: } // CHECK: func private @func_2_CPU_FLOAT(%[[VAL_0:.*]]: tensor<1xf32>, %[[VAL_1:.*]]: tensor<1xf32>) -> tensor<2x1xf32> attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} { -// CHECK: %[[VAL_2:.*]] = "tfl.pack"(%[[VAL_0]], %[[VAL_1]]) {axis = 0 : i32, tac.device = "CPU", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32> +// CHECK: %[[VAL_2:.*]] = "tfl.pack"(%[[VAL_0]], %[[VAL_1]]) <{axis = 0 : i32, values_count = 2 : i32}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32> // CHECK: return %[[VAL_2]] : tensor<2x1xf32> // CHECK: } @@ -53,7 +53,7 @@ module { // CHECK-DAG: %[[VAL_4:.*]] = "tfl.pseudo_const"(){{.*}}dense<[2, 1]> : tensor<2xi32> // CHECK: %[[VAL_5:.*]] = "tfl.reshape"(%[[VAL_0]], %[[VAL_2]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1xf32>, tensor<4xi32>) -> tensor<1x1x1x1xf32> // CHECK: %[[VAL_6:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_2]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1xf32>, tensor<4xi32>) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_7:.*]] = "tfl.concatenation"(%[[VAL_5]], %[[VAL_6]]) {axis = 3 : i32, fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x2xf32> +// CHECK: %[[VAL_7:.*]] = "tfl.concatenation"(%[[VAL_5]], %[[VAL_6]]) <{axis = 3 : i32, fused_activation_function = "NONE"}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x2xf32> // CHECK: %[[VAL_8:.*]] = "tfl.reshape"(%[[VAL_7]], %[[VAL_3]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32> // CHECK: %[[VAL_9:.*]] = "tfl.reshape"(%[[VAL_8]], %[[VAL_4]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32> // CHECK: return %[[VAL_9]] : tensor<2x1xf32> @@ -81,7 +81,7 @@ func.func private @func_10_CPU_FLOAT(%arg0: tensor<3xi32>, %arg1: tensor, % } // CHECK: func private @func_10_CPU_FLOAT(%[[VAL_0:.*]]: tensor<3xi32>, %[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor) -> tensor<*xf32> attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_10"} { -// CHECK: %[[VAL_4:.*]] = "tfl.one_hot"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) {axis = -1 : i32, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32> +// CHECK: %[[VAL_4:.*]] = "tfl.one_hot"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) <{axis = -1 : i32}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32> // CHECK: return %[[VAL_4]] : tensor<*xf32> // CHECK: } @@ -121,11 +121,11 @@ func.func private @quantize_ops_CPU_QUANTIZED_INT8(%arg0: tensor<384x512x!quant. 
// CHECK: func private @quantize_ops_CPU_QUANTIZED_INT8(%[[VAL_0:.*]]: tensor<384x512x!quant.uniform>, %[[VAL_1:.*]]: tensor<128x512x!quant.uniform:f32, 1.000000e-01>>, %[[VAL_2:.*]]: tensor<128x!quant.uniform>, %[[VAL_3:.*]]: tensor<128x!quant.uniform>) -> tensor<1x384x128x!quant.uniform> attributes {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "quantize_ops"} { // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<[1, 384, 128]> : tensor<3xi32> -// CHECK-DAG: %[[VAL_5:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<128x!quant.uniform>, value = dense<0> : tensor<128xi32>} : () -> tensor<128x!quant.uniform> -// CHECK: %[[VAL_6:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[VAL_1]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform>, tensor<128x512x!quant.uniform:f32, 1.000000e-01>>, tensor<128x!quant.uniform>) -> tensor<384x128x!quant.uniform> +// CHECK-DAG: %[[VAL_5:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<128x!quant.uniform>, value = dense<0> : tensor<128xi32>}> : () -> tensor<128x!quant.uniform> +// CHECK: %[[VAL_6:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[VAL_1]], %[[VAL_5]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<384x512x!quant.uniform>, tensor<128x512x!quant.uniform:f32, 1.000000e-01>>, tensor<128x!quant.uniform>) -> tensor<384x128x!quant.uniform> // CHECK: %[[VAL_7:.*]] = "tfl.reshape"(%[[VAL_6]], %[[VAL_4]]) {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<384x128x!quant.uniform>, tensor<3xi32>) -> tensor<1x384x128x!quant.uniform> -// CHECK: %[[VAL_8:.*]] = tfl.mul(%[[VAL_7]], %[[VAL_2]]) {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x384x128x!quant.uniform>, tensor<128x!quant.uniform>) -> tensor<1x384x128x!quant.uniform> -// CHECK: %[[VAL_9:.*]] = tfl.add(%[[VAL_8]], %[[VAL_3]]) {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x384x128x!quant.uniform>, tensor<128x!quant.uniform>) -> tensor<1x384x128x!quant.uniform> +// CHECK: %[[VAL_8:.*]] = tfl.mul(%[[VAL_7]], %[[VAL_2]]) <{fused_activation_function = "NONE"}> {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x384x128x!quant.uniform>, tensor<128x!quant.uniform>) -> tensor<1x384x128x!quant.uniform> +// CHECK: %[[VAL_9:.*]] = tfl.add(%[[VAL_8]], %[[VAL_3]]) <{fused_activation_function = "NONE"}> {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x384x128x!quant.uniform>, tensor<128x!quant.uniform>) -> tensor<1x384x128x!quant.uniform> // CHECK: return %[[VAL_9]] : tensor<1x384x128x!quant.uniform> // CHECK: } @@ -139,14 +139,14 @@ func.func private @quantize_ops_CPU_QUANTIZED_INT8(%arg0: tensor<384x512x!quant. 
// CHECK: %[[VAL_10:.*]] = "tfl.dequantize"(%[[VAL_1]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<128x512x!quant.uniform:f32, 1.000000e-01>>) -> tensor<128x512xf32> // CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_9]], %[[VAL_6]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<384x512xf32>, tensor<4xi32>) -> tensor<1x1x384x512xf32> // CHECK: %[[VAL_12:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_7]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<128x512xf32>, tensor<4xi32>) -> tensor<128x1x1x512xf32> -// CHECK: %[[VAL_13:.*]] = "tfl.conv_2d"(%[[VAL_11]], %[[VAL_12]], %[[VAL_4]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32, tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x1x384x512xf32>, tensor<128x1x1x512xf32>, tensor<128xf32>) -> tensor<1x1x384x128xf32> +// CHECK: %[[VAL_13:.*]] = "tfl.conv_2d"(%[[VAL_11]], %[[VAL_12]], %[[VAL_4]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x1x384x512xf32>, tensor<128x1x1x512xf32>, tensor<128xf32>) -> tensor<1x1x384x128xf32> // CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_13]], %[[VAL_8]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x1x384x128xf32>, tensor<2xi32>) -> tensor<384x128xf32> // CHECK: %[[VAL_15:.*]] = "tfl.reshape"(%[[VAL_14]], %[[VAL_5]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<384x128xf32>, tensor<3xi32>) -> tensor<1x384x128xf32> // CHECK: %[[VAL_16:.*]] = "tfl.dequantize"(%[[VAL_2]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<128x!quant.uniform>) -> tensor<128xf32> -// CHECK: %[[VAL_17:.*]] = tfl.mul(%[[VAL_15]], %[[VAL_16]]) {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>, tensor<128xf32>) -> tensor<1x384x128xf32> +// CHECK: %[[VAL_17:.*]] = tfl.mul(%[[VAL_15]], %[[VAL_16]]) <{fused_activation_function = "NONE"}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>, tensor<128xf32>) -> tensor<1x384x128xf32> // CHECK: %[[VAL_18:.*]] = "tfl.dequantize"(%[[VAL_3]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<128x!quant.uniform>) -> tensor<128xf32> -// CHECK: %[[VAL_19:.*]] = tfl.add(%[[VAL_17]], %[[VAL_18]]) {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>, tensor<128xf32>) -> tensor<1x384x128xf32> -// CHECK: %[[VAL_20:.*]] = "tfl.quantize"(%[[VAL_19]]) {qtype = tensor<1x384x128x!quant.uniform>, tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>) -> tensor<1x384x128x!quant.uniform> +// CHECK: %[[VAL_19:.*]] = tfl.add(%[[VAL_17]], %[[VAL_18]]) <{fused_activation_function = "NONE"}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>, tensor<128xf32>) -> tensor<1x384x128xf32> +// CHECK: %[[VAL_20:.*]] = "tfl.quantize"(%[[VAL_19]]) <{qtype = tensor<1x384x128x!quant.uniform>}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>) -> tensor<1x384x128x!quant.uniform> // CHECK: return %[[VAL_20]] : tensor<1x384x128x!quant.uniform> // CHECK: } @@ -155,13 +155,13 @@ func.func private @quantize_ops_CPU_QUANTIZED_INT8(%arg0: tensor<384x512x!quant. 
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<[1, 384, 128]> : tensor<3xi32> // CHECK: %[[VAL_6:.*]] = "tfl.dequantize"(%[[VAL_0]]) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<384x512x!quant.uniform>) -> tensor<384x512xf32> // CHECK: %[[VAL_7:.*]] = "tfl.dequantize"(%[[VAL_1]]) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<128x512x!quant.uniform:f32, 1.000000e-01>>) -> tensor<128x512xf32> -// CHECK: %[[VAL_8:.*]] = "tfl.fully_connected"(%[[VAL_6]], %[[VAL_7]], %[[VAL_4]]) {fused_activation_function = "NONE", keep_num_dims = false, tac.device = "CPU", tac.inference_type = "FLOAT", weights_format = "DEFAULT"} : (tensor<384x512xf32>, tensor<128x512xf32>, tensor<128xf32>) -> tensor<384x128xf32> +// CHECK: %[[VAL_8:.*]] = "tfl.fully_connected"(%[[VAL_6]], %[[VAL_7]], %[[VAL_4]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<384x512xf32>, tensor<128x512xf32>, tensor<128xf32>) -> tensor<384x128xf32> // CHECK: %[[VAL_9:.*]] = "tfl.reshape"(%[[VAL_8]], %[[VAL_5]]) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<384x128xf32>, tensor<3xi32>) -> tensor<1x384x128xf32> // CHECK: %[[VAL_10:.*]] = "tfl.dequantize"(%[[VAL_2]]) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<128x!quant.uniform>) -> tensor<128xf32> -// CHECK: %[[VAL_11:.*]] = tfl.mul(%[[VAL_9]], %[[VAL_10]]) {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>, tensor<128xf32>) -> tensor<1x384x128xf32> +// CHECK: %[[VAL_11:.*]] = tfl.mul(%[[VAL_9]], %[[VAL_10]]) <{fused_activation_function = "NONE"}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>, tensor<128xf32>) -> tensor<1x384x128xf32> // CHECK: %[[VAL_12:.*]] = "tfl.dequantize"(%[[VAL_3]]) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<128x!quant.uniform>) -> tensor<128xf32> -// CHECK: %[[VAL_13:.*]] = tfl.add(%[[VAL_11]], %[[VAL_12]]) {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>, tensor<128xf32>) -> tensor<1x384x128xf32> -// CHECK: %[[VAL_14:.*]] = "tfl.quantize"(%[[VAL_13]]) {qtype = tensor<1x384x128x!quant.uniform>, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>) -> tensor<1x384x128x!quant.uniform> +// CHECK: %[[VAL_13:.*]] = tfl.add(%[[VAL_11]], %[[VAL_12]]) <{fused_activation_function = "NONE"}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>, tensor<128xf32>) -> tensor<1x384x128xf32> +// CHECK: %[[VAL_14:.*]] = "tfl.quantize"(%[[VAL_13]]) <{qtype = tensor<1x384x128x!quant.uniform>}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1x384x128xf32>) -> tensor<1x384x128x!quant.uniform> // CHECK: return %[[VAL_14]] : tensor<1x384x128x!quant.uniform> // CHECK: } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/pick-subgraphs.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/pick-subgraphs.mlir index 0157e97a4e4ac3..b309fb513b1fe7 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/pick-subgraphs.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/pick-subgraphs.mlir @@ -93,12 +93,12 @@ module { } // CHECK: func @main([[VAL_0:%.*]]: tensor<1x200x200x200xf32>) -> tensor<2x1x200x200x200xf32> attributes {tf.entry_function = {inputs = "Placeholder", outputs = "mul_1"}} { -// CHECK: [[VAL_1:%.*]] = "tfl.pseudo_const"() {value = 
dense<0.962260901> : tensor<1xf32>} : () -> tensor<1xf32> +// CHECK: [[VAL_1:%.*]] = "tfl.pseudo_const"() <{value = dense<0.962260901> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: [[VAL_2:%.*]] = call @func_0_GPU_FLOAT([[VAL_0]], [[VAL_1]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<1x200x200x200xf32>, tensor<1xf32>) -> tensor<1x200x200x200xf32> -// CHECK: [[VAL_3:%.*]] = "tfl.pseudo_const"() {value = dense<0.895973444> : tensor<1xf32>} : () -> tensor<1xf32> +// CHECK: [[VAL_3:%.*]] = "tfl.pseudo_const"() <{value = dense<0.895973444> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: [[VAL_4:%.*]] = call @func_1_GPU_FLOAT([[VAL_0]], [[VAL_3]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} : (tensor<1x200x200x200xf32>, tensor<1xf32>) -> tensor<1x200x200x200xf32> // CHECK: [[VAL_5:%.*]] = call @func_2_GPU_FLOAT([[VAL_4]], [[VAL_2]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} : (tensor<1x200x200x200xf32>, tensor<1x200x200x200xf32>) -> tensor<2x1x200x200x200xf32> -// CHECK: [[VAL_6:%.*]] = "tfl.pseudo_const"() {value = dense<0.0778453499> : tensor<1xf32>} : () -> tensor<1xf32> +// CHECK: [[VAL_6:%.*]] = "tfl.pseudo_const"() <{value = dense<0.0778453499> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: [[VAL_7:%.*]] = call @func_3_GPU_FLOAT([[VAL_5]], [[VAL_6]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_3"} : (tensor<2x1x200x200x200xf32>, tensor<1xf32>) -> tensor<2x1x200x200x200xf32> // CHECK: return [[VAL_7]] : tensor<2x1x200x200x200xf32> // CHECK: } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir index 3018221fdacff6..0934bb387a22c5 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir @@ -116,7 +116,7 @@ func.func @simpleTest(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>, %arg2: tensor< // CHECK: } // CHECK: func private @func_1_CPU_FLOAT(%[[VAL_0:.*]]: tensor<1xf32>, %[[VAL_1:.*]]: tensor<1xf32>) -> tensor<2x1xf32> attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} { -// CHECK: %[[VAL_2:.*]] = "tfl.pack"(%[[VAL_0]], %[[VAL_1]]) {axis = 0 : i32, tac.device = "CPU", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32> +// CHECK: %[[VAL_2:.*]] = "tfl.pack"(%[[VAL_0]], %[[VAL_1]]) <{axis = 0 : i32, values_count = 2 : i32}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32> // CHECK: return %[[VAL_2]] : tensor<2x1xf32> // CHECK: } @@ -134,17 +134,17 @@ func.func @constWeight(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf3 } // CHECK: func @constWeight(%[[VAL_0:.*]]: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x3x3x3xf32>} : () -> tensor<16x3x3x3xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16xf32>} : () -> tensor<16xf32> -// CHECK-DAG: %[[VAL_3:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x3x3x16xf32>} : () -> tensor<16x3x3x16xf32> -// CHECK-DAG: %[[VAL_4:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16xf32>} : () -> tensor<16xf32> +// 
CHECK-DAG: %[[VAL_1:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<16x3x3x3xf32>}> : () -> tensor<16x3x3x3xf32> +// CHECK-DAG: %[[VAL_2:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<16xf32>}> : () -> tensor<16xf32> +// CHECK-DAG: %[[VAL_3:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<16x3x3x16xf32>}> : () -> tensor<16x3x3x16xf32> +// CHECK-DAG: %[[VAL_4:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<16xf32>}> : () -> tensor<16xf32> // CHECK: %[[VAL_5:.*]] = call @func_0_GPU_FLOAT(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> // CHECK: return %[[VAL_5]] : tensor<256x30x30x16xf32> // CHECK: } // CHECK: func private @func_0_GPU_FLOAT(%[[VAL_0:.*]]: tensor<256x32x32x3xf32>, %[[VAL_1:.*]]: tensor<16x3x3x3xf32>, %[[VAL_2:.*]]: tensor<16xf32>, %[[VAL_3:.*]]: tensor<16x3x3x16xf32>, %[[VAL_4:.*]]: tensor<16xf32>) -> tensor<256x30x30x16xf32> attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { -// CHECK: %[[VAL_5:.*]] = "tfl.conv_2d"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32, tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> -// CHECK: %[[VAL_6:.*]] = "tfl.conv_2d"(%[[VAL_5]], %[[VAL_3]], %[[VAL_4]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<256x30x30x16xf32>, tensor<16x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> +// CHECK: %[[VAL_5:.*]] = "tfl.conv_2d"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> +// CHECK: %[[VAL_6:.*]] = "tfl.conv_2d"(%[[VAL_5]], %[[VAL_3]], %[[VAL_4]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<256x30x30x16xf32>, tensor<16x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> // CHECK: return %[[VAL_6]] : tensor<256x30x30x16xf32> // CHECK: } @@ -166,15 +166,15 @@ func.func @norm1(%arg0: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { } // CHECK: func @norm1(%[[VAL_0:.*]]: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<128xf32>} : () -> tensor<128xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tfl.pseudo_const"() {value = dense<128> : tensor<2xi32>} : () -> tensor<2xi32> -// CHECK-DAG: %[[VAL_3:.*]] = "tfl.pseudo_const"() {value = dense<[1, 128, 128]> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK-DAG: %[[VAL_1:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32> +// CHECK-DAG: %[[VAL_2:.*]] = "tfl.pseudo_const"() <{value = 
dense<128> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK-DAG: %[[VAL_3:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 128, 128]> : tensor<3xi32>}> : () -> tensor<3xi32> // CHECK: %[[VAL_4:.*]] = call @func_0_GPU_FLOAT(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<1x128x128xf32>, tensor<128xf32>, tensor<2xi32>, tensor<3xi32>) -> tensor<1x128x128xf32> // CHECK: return %[[VAL_4]] : tensor<1x128x128xf32> // CHECK: } // CHECK: func private @func_0_GPU_FLOAT(%[[VAL_0:.*]]: tensor<1x128x128xf32>, %[[VAL_1:.*]]: tensor<128xf32>, %[[VAL_2:.*]]: tensor<2xi32>, %[[VAL_3:.*]]: tensor<3xi32>) -> tensor<1x128x128xf32> attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { -// CHECK: %[[VAL_4:.*]] = tfl.add(%[[VAL_0]], %[[VAL_1]]) {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> +// CHECK: %[[VAL_4:.*]] = tfl.add(%[[VAL_0]], %[[VAL_1]]) <{fused_activation_function = "NONE"}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK: %[[VAL_5:.*]] = "tfl.reshape"(%[[VAL_4]], %[[VAL_2]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x128x128xf32>, tensor<2xi32>) -> tensor<128x128xf32> // CHECK: %[[VAL_6:.*]] = "tfl.relu"(%[[VAL_5]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<128x128xf32>) -> tensor<128x128xf32> // CHECK: %[[VAL_7:.*]] = "tfl.reshape"(%[[VAL_6]], %[[VAL_3]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<128x128xf32>, tensor<3xi32>) -> tensor<1x128x128xf32> @@ -204,19 +204,19 @@ func.func @norm2(%arg0: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { } // CHECK: func @norm2(%[[VAL_0:.*]]: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<128xf32>} : () -> tensor<128xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tfl.pseudo_const"() {value = dense<128> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[VAL_1:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32> +// CHECK-DAG: %[[VAL_2:.*]] = "tfl.pseudo_const"() <{value = dense<128> : tensor<2xi32>}> : () -> tensor<2xi32> // CHECK: %[[VAL_3:.*]]:2 = call @func_0_GPU_FLOAT(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<1x128x128xf32>, tensor<128xf32>, tensor<2xi32>) -> (tensor<1x128x128xf32>, tensor<128x128xf32>) -// CHECK-DAG: %[[VAL_4:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<128x128xf32>} : () -> tensor<128x128xf32> -// CHECK-DAG: %[[VAL_5:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<128xf32>} : () -> tensor<128xf32> +// CHECK-DAG: %[[VAL_4:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<128x128xf32>}> : () -> tensor<128x128xf32> +// CHECK-DAG: %[[VAL_5:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32> // CHECK: %[[VAL_6:.*]] = call @func_2_CPU_FLOAT(%[[VAL_3]]#1, %[[VAL_4]], %[[VAL_5]]) {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<128x128xf32> -// CHECK: %[[VAL_7:.*]] = "tfl.pseudo_const"() {value = dense<[1, 128, 128]> : tensor<3xi32>} : () -> 
tensor<3xi32> +// CHECK: %[[VAL_7:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 128, 128]> : tensor<3xi32>}> : () -> tensor<3xi32> // CHECK: %[[VAL_8:.*]] = call @func_1_GPU_FLOAT(%[[VAL_6]], %[[VAL_7]], %[[VAL_3]]#0) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} : (tensor<128x128xf32>, tensor<3xi32>, tensor<1x128x128xf32>) -> tensor<1x128x128xf32> // CHECK: return %[[VAL_8]] : tensor<1x128x128xf32> // CHECK: } // CHECK: func.func private @func_0_GPU_FLOAT(%[[VAL_0:.*]]: tensor<1x128x128xf32>, %[[VAL_1:.*]]: tensor<128xf32>, %[[VAL_2:.*]]: tensor<2xi32>) -> (tensor<1x128x128xf32>, tensor<128x128xf32>) attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { -// CHECK: %[[VAL_3:.*]] = tfl.add(%[[VAL_0]], %[[VAL_1]]) {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> +// CHECK: %[[VAL_3:.*]] = tfl.add(%[[VAL_0]], %[[VAL_1]]) <{fused_activation_function = "NONE"}> {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK: %[[VAL_4:.*]] = "tfl.reshape"(%[[VAL_3]], %[[VAL_2]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x128x128xf32>, tensor<2xi32>) -> tensor<128x128xf32> // CHECK: %[[VAL_5:.*]] = "tfl.relu"(%[[VAL_4]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<128x128xf32>) -> tensor<128x128xf32> // CHECK: return %[[VAL_3]], %[[VAL_5]] : tensor<1x128x128xf32>, tensor<128x128xf32> @@ -229,7 +229,7 @@ func.func @norm2(%arg0: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { // CHECK: } // CHECK: func.func private @func_2_CPU_FLOAT(%[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<128x128xf32>, %[[VAL_2:.*]]: tensor<128xf32>) -> tensor<128x128xf32> attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} { -// CHECK: %[[VAL_3:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {fused_activation_function = "NONE", keep_num_dims = false, tac.device = "CPU", tac.inference_type = "FLOAT", weights_format = "DEFAULT"} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<128x128xf32> +// CHECK: %[[VAL_3:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<128x128xf32> // CHECK: return %[[VAL_3]] : tensor<128x128xf32> // CHECK: } @@ -248,8 +248,8 @@ func.func @quantizedOpOnly(%arg0: tensor<1x!quant.uniform>, } // CHECK: func @quantizedOpOnly(%[[VAL_0:.*]]: tensor<1x!quant.uniform>, %[[VAL_1:.*]]: tensor<1x!quant.uniform>) -> tensor<2x1x!quant.uniform> { -// CHECK: %[[VAL_2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x!quant.uniform>, value = dense<127> : tensor<1xi8>} : () -> tensor<1x!quant.uniform> -// CHECK: %[[VAL_3:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x!quant.uniform>, value = dense<127> : tensor<1xi8>} : () -> tensor<1x!quant.uniform> +// CHECK: %[[VAL_2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x!quant.uniform>, value = dense<127> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform> +// CHECK: %[[VAL_3:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x!quant.uniform>, value = dense<127> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform> // CHECK: %[[VAL_4:.*]] = call @func_0_CPU_QUANTIZED_INT8(%[[VAL_0]], 
%[[VAL_2]], %[[VAL_3]], %[[VAL_1]]) {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "func_0"} : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<2x1x!quant.uniform> // CHECK: return %[[VAL_4]] : tensor<2x1x!quant.uniform> // CHECK: } @@ -258,7 +258,7 @@ func.func @quantizedOpOnly(%arg0: tensor<1x!quant.uniform>, // CHECK: %[[VAL_5:.*]] = tfl.mul %[[VAL_0]], %[[VAL_1]] {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : tensor<1x!quant.uniform> // CHECK: %[[VAL_6:.*]] = tfl.add %[[VAL_5]], %[[VAL_2]] {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : tensor<1x!quant.uniform> // CHECK: %[[VAL_7:.*]] = tfl.add %[[VAL_3]], %[[VAL_1]] {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : tensor<1x!quant.uniform> -// CHECK: %[[VAL_8:.*]] = "tfl.pack"(%[[VAL_6]], %[[VAL_7]]) {axis = 0 : i32, tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", values_count = 2 : i32} : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<2x1x!quant.uniform> +// CHECK: %[[VAL_8:.*]] = "tfl.pack"(%[[VAL_6]], %[[VAL_7]]) <{axis = 0 : i32, values_count = 2 : i32}> {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<2x1x!quant.uniform> // CHECK: return %[[VAL_8]] : tensor<2x1x!quant.uniform> // CHECK: } @@ -280,12 +280,12 @@ func.func @quantizationWithFloat(%arg0: tensor<1x1x384x!quant.uniform>, %[[VAL_1:.*]]: tensor<1x1x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> { -// CHECK: %[[VAL_2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x384x1x!quant.uniform>, value = dense<127> : tensor<1x384x1xi8>} : () -> tensor<1x384x1x!quant.uniform> +// CHECK: %[[VAL_2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x384x1x!quant.uniform>, value = dense<127> : tensor<1x384x1xi8>}> : () -> tensor<1x384x1x!quant.uniform> // CHECK: %[[VAL_3:.*]] = call @func_1_CPU_QUANTIZED_INT8(%[[VAL_0]], %[[VAL_2]]) {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "func_1"} : (tensor<1x1x384x!quant.uniform>, tensor<1x384x1x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> // CHECK: %[[VAL_4:.*]] = "tfl.dequantize"(%[[VAL_3]]) : (tensor<1x384x384x!quant.uniform>) -> tensor<1x384x384xf32> -// CHECK: %[[VAL_5:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x1x384xf32>} : () -> tensor<1x384x384xf32> +// CHECK: %[[VAL_5:.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<1x1x384xf32>}> : () -> tensor<1x384x384xf32> // CHECK: %[[VAL_6:.*]] = call @func_0_GPU_FLOAT(%[[VAL_4]], %[[VAL_5]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<1x384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32> -// CHECK: %[[VAL_7:.*]] = "tfl.quantize"(%[[VAL_6]]) {qtype = tensor<1x384x1x!quant.uniform>} : (tensor<1x384x384xf32>) -> tensor<1x384x384x!quant.uniform> +// CHECK: %[[VAL_7:.*]] = "tfl.quantize"(%[[VAL_6]]) <{qtype = tensor<1x384x1x!quant.uniform>}> : (tensor<1x384x384xf32>) -> tensor<1x384x384x!quant.uniform> // CHECK: %[[VAL_8:.*]] = call @func_2_CPU_QUANTIZED_INT8(%[[VAL_1]], %[[VAL_7]]) {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "func_2"} : (tensor<1x1x384x!quant.uniform>, tensor<1x384x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> // CHECK: return %[[VAL_8]] : 
tensor<1x384x384x!quant.uniform> // CHECK: } @@ -296,12 +296,12 @@ func.func @quantizationWithFloat(%arg0: tensor<1x1x384x!quant.uniform>, %[[VAL_1:.*]]: tensor<1x384x1x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> attributes {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "func_1"} { -// CHECK: %[[VAL_2:.*]] = tfl.mul(%[[VAL_0]], %[[VAL_1]]) {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x1x384x!quant.uniform>, tensor<1x384x1x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> +// CHECK: %[[VAL_2:.*]] = tfl.mul(%[[VAL_0]], %[[VAL_1]]) <{fused_activation_function = "NONE"}> {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x1x384x!quant.uniform>, tensor<1x384x1x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> // CHECK: return %[[VAL_2]] : tensor<1x384x384x!quant.uniform> // CHECK: } // CHECK: func private @func_2_CPU_QUANTIZED_INT8(%[[VAL_0:.*]]: tensor<1x1x384x!quant.uniform>, %[[VAL_1:.*]]: tensor<1x384x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> attributes {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "func_2"} { -// CHECK: %[[VAL_2:.*]] = tfl.mul(%[[VAL_0]], %[[VAL_1]]) {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x1x384x!quant.uniform>, tensor<1x384x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> +// CHECK: %[[VAL_2:.*]] = tfl.mul(%[[VAL_0]], %[[VAL_1]]) <{fused_activation_function = "NONE"}> {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x1x384x!quant.uniform>, tensor<1x384x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> // CHECK: return %[[VAL_2]] : tensor<1x384x384x!quant.uniform> // CHECK: } @@ -360,28 +360,28 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor, tensor) -> tensor %31 = tfl.add %30, %25 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor %32 = "tfl.reshape"(%31, %cst_8) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor - %33 = "tfl.gather"(%28, %32) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %33 = "tfl.gather"(%28, %32) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor %34 = "tfl.reshape"(%33, %cst_8) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor %35 = "tfl.shape"(%34) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<1xi32> %36 = "tfl.fill"(%35, %cst_7) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor %37 = "tfl.expand_dims"(%arg10, %cst_9) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor %38 = tfl.add %37, %15 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor %39 = "tfl.reshape"(%38, %cst_8) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor - %40 = "tfl.gather"(%18, %39) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %40 = "tfl.gather"(%18, %39) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor %41 = "tfl.reshape"(%40, %cst_8) {tac.device = "DARWINN", 
tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor %42 = "tfl.shape"(%41) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<1xi32> %43 = "tfl.fill"(%42, %cst_7) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor - %44 = "tfl.gather"(%arg2, %arg8) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %44 = "tfl.gather"(%arg2, %arg8) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor %45 = "tfl.equal"(%44, %cst_6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor %46 = "tfl.custom"(%45, %36, %34) {custom_code = "FlexSelect", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor %47 = "tfl.custom"(%arg11, %arg8, %46) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> %48 = "tfl.equal"(%44, %cst_11) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor %49 = "tfl.custom"(%48, %43, %41) {custom_code = "FlexSelect", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor %50 = "tfl.custom"(%arg12, %arg8, %49) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> - %51 = "tfl.gather"(%cst_5, %44) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor + %51 = "tfl.gather"(%cst_5, %44) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor %52 = tfl.add %arg9, %51 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor %53 = "tfl.custom"(%arg13, %arg8, %52) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> - %54 = "tfl.gather"(%cst_4, %44) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor + %54 = "tfl.gather"(%cst_4, %44) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor %55 = tfl.add %arg10, %54 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor %56 = "tfl.custom"(%arg14, %arg8, %55) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> %57 = tfl.add %arg7, %cst_11 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor @@ -406,12 +406,12 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>) -> tensor attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { // CHECK: %0 = "tfl.shape"(%arg0) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> -// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = 
"DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor // CHECK: return %1 : tensor // CHECK: } // CHECK: func.func private @func_1_CPU_FLOAT(%arg0: tensor<1xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor<2xi32>) -> (tensor, tensor) attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} { -// CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "FlexTensorListReserve", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> -// CHECK: %1 = "tfl.custom"(%arg0, %arg1) {custom_code = "FlexTensorListReserve", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> +// CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "FlexTensorListReserve", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> +// CHECK: %1 = "tfl.custom"(%arg0, %arg1) <{custom_code = "FlexTensorListReserve", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> // CHECK: %2:8 = "tfl.while"(%arg2, %arg2, %arg3, %arg4, %0, %0, %1, %1) ({ // CHECK: ^bb0(%arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor>>, %arg14: tensor>>, %arg15: tensor>>, %arg16: tensor>>): // CHECK: %7 = func.call @func_2_DARWINN_FLOAT(%arg10, %arg1, %arg9) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} : (tensor, tensor, tensor) -> tensor @@ -430,24 +430,24 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor : tensor<1xi32> // CHECK: %cst_9 = arith.constant dense<0> : tensor<1xi32> // CHECK: %7:2 = func.call @func_3_DARWINN_FLOAT(%arg5, %cst_9, %cst_8, %cst_7, %cst_6, %cst_5) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_3"} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor, tensor, tensor<1xi32>) -> (tensor, tensor<2xi32>) -// CHECK: %8 = "tfl.reduce_prod"(%7#1, %cst_9) {keep_dims = true, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> +// CHECK: %8 = "tfl.reduce_prod"(%7#1, %cst_9) <{keep_dims = true}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: %9:3 = func.call @func_4_DARWINN_FLOAT(%arg5, %8, %arg6, %cst_9, %cst_8, %cst_7, %cst_6, %cst_5) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_4"} : (tensor, tensor<1xi32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor, tensor, tensor<1xi32>) -> (tensor, tensor, tensor<2xi32>) -// CHECK: %10 = "tfl.reduce_prod"(%9#2, %cst_9) {keep_dims = true, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> +// CHECK: %10 = "tfl.reduce_prod"(%9#2, %cst_9) <{keep_dims = true}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: %11 = "tfl.expand_dims"(%arg11, %cst_4) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor 
// CHECK: %12 = "tfl.expand_dims"(%arg12, %cst_4) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor // CHECK: %13:10 = func.call @func_5_DARWINN_FLOAT(%arg6, %10, %arg10, %cst_6, %11, %9#1, %cst_3, %cst_2, %12, %7#0, %9#0, %arg7, %cst_1, %cst_0, %arg11, %cst, %arg12, %arg9) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_5"} : (tensor, tensor<1xi32>, tensor, tensor, tensor, tensor, tensor<1xi32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor<5xi32>, tensor, tensor<5xi32>, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -// CHECK: %14 = "tfl.custom"(%13#5, %13#2, %13#1) {custom_code = "FlexSelect", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor -// CHECK: %15 = "tfl.custom"(%arg13, %arg10, %14) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> -// CHECK: %16 = "tfl.custom"(%13#6, %13#4, %13#3) {custom_code = "FlexSelect", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor -// CHECK: %17 = "tfl.custom"(%arg14, %arg10, %16) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> -// CHECK: %18 = "tfl.custom"(%arg15, %arg10, %13#7) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> -// CHECK: %19 = "tfl.custom"(%arg16, %arg10, %13#8) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> +// CHECK: %14 = "tfl.custom"(%13#5, %13#2, %13#1) <{custom_code = "FlexSelect", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor +// CHECK: %15 = "tfl.custom"(%arg13, %arg10, %14) <{custom_code = "FlexTensorListSetItem", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> +// CHECK: %16 = "tfl.custom"(%13#6, %13#4, %13#3) <{custom_code = "FlexSelect", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor +// CHECK: %17 = "tfl.custom"(%arg14, %arg10, %16) <{custom_code = "FlexTensorListSetItem", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> +// CHECK: %18 = "tfl.custom"(%arg15, %arg10, %13#7) <{custom_code = "FlexTensorListSetItem", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> +// CHECK: %19 = "tfl.custom"(%arg16, %arg10, %13#8) <{custom_code = "FlexTensorListSetItem", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> // CHECK: "tfl.yield"(%13#9, %13#0, %13#7, %13#8, %15, %17, %18, %19) : (tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -> () // CHECK: }) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -> (tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -// CHECK: %3 = "tfl.custom"(%2#4, %arg0) {custom_code = "FlexTensorListStack", custom_option = #tfl, tac.device = "CPU", 
tac.inference_type = "FLOAT"} : (tensor>>, tensor<1xi32>) -> tensor -// CHECK: %4 = "tfl.custom"(%3, %arg8) {custom_code = "FlexTranspose", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor -// CHECK: %5 = "tfl.custom"(%2#5, %arg0) {custom_code = "FlexTensorListStack", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor<1xi32>) -> tensor -// CHECK: %6 = "tfl.custom"(%5, %arg8) {custom_code = "FlexTranspose", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor +// CHECK: %3 = "tfl.custom"(%2#4, %arg0) <{custom_code = "FlexTensorListStack", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor<1xi32>) -> tensor +// CHECK: %4 = "tfl.custom"(%3, %arg8) <{custom_code = "FlexTranspose", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor +// CHECK: %5 = "tfl.custom"(%2#5, %arg0) <{custom_code = "FlexTensorListStack", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor<1xi32>) -> tensor +// CHECK: %6 = "tfl.custom"(%5, %arg8) <{custom_code = "FlexTranspose", custom_option = #tfl}> {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor // CHECK: return %4, %6 : tensor, tensor // CHECK: } // CHECK: func.func private @func_2_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} { @@ -458,25 +458,25 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor<1xi32>) -> (tensor, tensor<2xi32>) attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_3"} { // CHECK: %0 = "tfl.shape"(%arg0) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> -// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor // CHECK: %2 = "tfl.range"(%arg3, %1, %arg4) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor -// CHECK: %3 = "tfl.pack"(%1, %arg4) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> -// CHECK: %4 = "tfl.strided_slice"(%0, %arg2, %arg5, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor -// CHECK: %5 = tfl.mul(%2, %4) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %3 = "tfl.pack"(%1, %arg4) <{axis = 0 : i32, values_count = 2 
: i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor<2xi32> +// CHECK: %4 = "tfl.strided_slice"(%0, %arg2, %arg5, %arg2) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %5 = tfl.mul(%2, %4) <{fused_activation_function = "NONE"}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor // CHECK: %6 = "tfl.reshape"(%5, %3) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor -// CHECK: %7 = "tfl.strided_slice"(%0, %arg1, %arg5, %arg2) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK: %7 = "tfl.strided_slice"(%0, %arg1, %arg5, %arg2) <{begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> // CHECK: return %6, %7 : tensor, tensor<2xi32> // CHECK: } // CHECK: func.func private @func_4_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor<1xi32>, %arg4: tensor<1xi32>, %arg5: tensor, %arg6: tensor, %arg7: tensor<1xi32>) -> (tensor, tensor, tensor<2xi32>) attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_4"} { // CHECK: %0 = "tfl.reshape"(%arg0, %arg1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor // CHECK: %1 = "tfl.shape"(%arg2) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> -// CHECK: %2 = "tfl.strided_slice"(%1, %arg3, %arg4, %arg4) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %2 = "tfl.strided_slice"(%1, %arg3, %arg4, %arg4) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor // CHECK: %3 = "tfl.range"(%arg5, %2, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor -// CHECK: %4 = "tfl.pack"(%2, %arg6) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> -// CHECK: %5 = "tfl.strided_slice"(%1, %arg4, %arg7, %arg4) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor -// CHECK: %6 = tfl.mul(%3, %5) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %4 = "tfl.pack"(%2, %arg6) <{axis = 0 : i32, values_count = 2 : i32}> {tac.device = 
"DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor<2xi32> +// CHECK: %5 = "tfl.strided_slice"(%1, %arg4, %arg7, %arg4) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %6 = tfl.mul(%3, %5) <{fused_activation_function = "NONE"}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor // CHECK: %7 = "tfl.reshape"(%6, %4) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor -// CHECK: %8 = "tfl.strided_slice"(%1, %arg3, %arg7, %arg4) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK: %8 = "tfl.strided_slice"(%1, %arg3, %arg7, %arg4) <{begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> // CHECK: return %0, %7, %8 : tensor, tensor, tensor<2xi32> // CHECK: } // CHECK: func.func private @func_5_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor<1xi32>, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor<5xi32>, %arg14: tensor, %arg15: tensor<5xi32>, %arg16: tensor, %arg17: tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_5"} { @@ -484,22 +484,22 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor // CHECK: %2 = tfl.add %arg4, %arg5 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor // CHECK: %3 = "tfl.reshape"(%2, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor -// CHECK: %4 = "tfl.gather"(%0, %3) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %4 = "tfl.gather"(%0, %3) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor // CHECK: %5 = "tfl.reshape"(%4, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor // CHECK: %6 = "tfl.shape"(%5) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<1xi32> // CHECK: %7 = "tfl.fill"(%6, %arg7) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor // CHECK: %8 = tfl.add %arg8, %arg9 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor // CHECK: %9 = "tfl.reshape"(%8, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor -// CHECK: %10 = "tfl.gather"(%arg10, %9) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %10 = "tfl.gather"(%arg10, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", 
tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor // CHECK: %11 = "tfl.reshape"(%10, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor // CHECK: %12 = "tfl.shape"(%11) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<1xi32> // CHECK: %13 = "tfl.fill"(%12, %arg7) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor -// CHECK: %14 = "tfl.gather"(%arg11, %arg2) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %14 = "tfl.gather"(%arg11, %arg2) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor // CHECK: %15 = "tfl.equal"(%14, %arg12) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor // CHECK: %16 = "tfl.equal"(%14, %arg3) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor -// CHECK: %17 = "tfl.gather"(%arg13, %14) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor +// CHECK: %17 = "tfl.gather"(%arg13, %14) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor // CHECK: %18 = tfl.add %arg14, %17 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor -// CHECK: %19 = "tfl.gather"(%arg15, %14) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor +// CHECK: %19 = "tfl.gather"(%arg15, %14) <{axis = 0 : i32, batch_dims = 0 : i32}> {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor // CHECK: %20 = tfl.add %arg16, %19 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor // CHECK: %21 = tfl.add %arg17, %arg3 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor // CHECK: return %1, %5, %7, %11, %13, %15, %16, %18, %20, %21 : tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc index 4efdd053eec5c2..701d9cad1c34c1 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc @@ -59,13 +59,14 @@ int64_t GetTransferredTensorBytes(func::CallOp from_graph, for (auto input : to_graph.getOperands()) { Operation* input_op = input.getDefiningOp(); if (input_op && input_op == from_graph.getOperation()) { - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = + mlir::dyn_cast_or_null(input.getType()); if (input_type == nullptr || !input_type.hasStaticShape()) continue; // Quantized type does not support getSizeInBits. 
if (IsQUI8Type(input_type) || IsQI8Type(input_type)) { total_size_transferred += input_type.getNumElements() * 8; } else { - auto s_type = input_type.cast(); + auto s_type = mlir::cast(input_type); total_size_transferred += s_type.getNumElements() * s_type.getElementTypeBitWidth(); } @@ -81,7 +82,8 @@ int64_t GetTransferredElementCount(func::CallOp from_graph, for (auto input : to_graph.getOperands()) { Operation* input_op = input.getDefiningOp(); if (input_op && input_op == from_graph.getOperation()) { - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = + mlir::dyn_cast_or_null(input.getType()); if (input_type == nullptr || !input_type.hasStaticShape()) continue; total_element_count += input_type.getNumElements(); } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc index ea1c299fd546c1..0c37a8da20575f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc @@ -156,13 +156,13 @@ struct FoldQuantizedI32ToFloat : public OpRewritePattern { if (!IsQI32Type(input_dequant.getType())) return failure(); auto output_type = - dequant_op.getOutput().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(dequant_op.getOutput().getType()); if (!output_type || !output_type.getElementType().isF32()) return failure(); - auto input_type = input_dequant.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input_dequant.getType()); // TODO(renjieliu): support UniformQuantizedPerAxisType. - auto q_type = input_type.getElementType() - .dyn_cast_or_null(); + auto q_type = mlir::dyn_cast_or_null( + input_type.getElementType()); if (!q_type) return failure(); const float scale = q_type.getScale(); @@ -183,9 +183,9 @@ struct FoldQuantizedI32ToFloat : public OpRewritePattern { }; auto dequant_values = - input_values.cast().mapValues( - FloatType::getF32(rewriter.getContext()), - llvm::function_ref(dequantize_func)); + mlir::cast(input_values) + .mapValues(FloatType::getF32(rewriter.getContext()), + llvm::function_ref(dequantize_func)); rewriter.replaceOpWithNewOp(dequant_op, dequant_op.getType(), dequant_values); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc index baf25aa54c109b..278c54e8805f3d 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc @@ -96,11 +96,11 @@ LogicalResult EnsureBias(Operation* op, int bias_idx, PatternRewriter& rewriter) { auto bias = op->getOperand(bias_idx); - if (!bias.getType().isa()) return failure(); + if (!mlir::isa(bias.getType())) return failure(); // Proceed to create a zero bias. auto output = op->getResult(0); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null(output.getType()); if (!output_type) return failure(); // bias should be a vector sized of the last output dim. 
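[Editor's note] The C++ hunks above and below all apply one mechanical migration: MLIR's deprecated member-function casts (`x.getType().dyn_cast_or_null<T>()`, `.cast<T>()`, `.isa<T>()`) are rewritten as the namespaced free functions `mlir::dyn_cast_or_null<T>(...)`, `mlir::cast<T>(...)`, and `mlir::isa<T>(...)`. A minimal sketch of the idiom follows; it is not part of the patch, and the helper name and header set are assumptions based on a standard MLIR checkout.

// Illustrative sketch only (hypothetical helper, not from the patch).
#include <cstdint>

#include "mlir/IR/BuiltinTypes.h"  // mlir::RankedTensorType
#include "mlir/IR/Value.h"         // mlir::Value
#include "mlir/Support/LLVM.h"     // re-exports llvm::dyn_cast_or_null into mlir::

// Returns the element count of `value` if it is a statically shaped ranked
// tensor, or -1 otherwise (the same guard GetTransferredTensorBytes uses above).
int64_t NumElementsOrMinusOne(mlir::Value value) {
  // Deprecated member form:
  //   value.getType().dyn_cast_or_null<mlir::RankedTensorType>()
  auto tensor_type =
      mlir::dyn_cast_or_null<mlir::RankedTensorType>(value.getType());
  if (!tensor_type || !tensor_type.hasStaticShape()) return -1;
  return tensor_type.getNumElements();
}

The free-function spelling works uniformly on `Type`, `Attribute`, and `Location` handles, which is why the identical rewrite appears later in the patch for `ElementsAttr`, `StringAttr`, and `NamedLoc` as well.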
@@ -163,7 +163,7 @@ SmallVector SliceOutputs(Operation* split_op, Value input, SmallVector slice_size; auto current_output = split_op->getResult(i); auto current_output_type = - current_output.getType().cast(); + mlir::cast(current_output.getType()); for (int d = 0; d < input_type.getRank(); ++d) { if (d == split_dim) { // Split dimension. @@ -208,7 +208,7 @@ LogicalResult LowerPackIntoConcatReshape::matchAndRewrite( TFL::PackOp pack_op, PatternRewriter& rewriter) const { // Pack op should have same shape type. SmallVector pack_inputs(pack_op.getValues()); - auto input_type = pack_inputs[0].getType().dyn_cast(); + auto input_type = mlir::dyn_cast(pack_inputs[0].getType()); if (!input_type) return failure(); // Figure out output shapes. @@ -266,8 +266,8 @@ LogicalResult SquaredDifference::matchAndRewrite( TFL::SquaredDifferenceOp squared_diff_op, PatternRewriter& rewriter) const { auto x = squared_diff_op.getLhs(); auto y = squared_diff_op.getRhs(); - auto x_type = x.getType().dyn_cast(); - auto y_type = y.getType().dyn_cast(); + auto x_type = mlir::dyn_cast(x.getType()); + auto y_type = mlir::dyn_cast(y.getType()); if (!x_type || !y_type) return failure(); if (x_type.getShape() != y_type.getShape()) return failure(); @@ -290,16 +290,16 @@ LogicalResult UnrollSplit::matchAndRewrite(TFL::SplitOp split_op, PatternRewriter& rewriter) const { auto num_splits = split_op.getNumSplits(); auto input = split_op.getValue(); - auto input_type = input.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); if (input_type == nullptr || !input_type.hasStaticShape()) return failure(); for (auto result : split_op.getResults()) { - auto result_type = result.getType().dyn_cast(); + auto result_type = mlir::dyn_cast(result.getType()); if (result_type == nullptr) return failure(); } auto output = split_op.getResult(0); - auto output_type = output.getType().cast(); + auto output_type = mlir::cast(output.getType()); // TODO(renjieliu): change to use split_dim when we raise the constants // as well. @@ -330,11 +330,11 @@ LogicalResult UnrollSplitV::matchAndRewrite(TFL::SplitVOp splitv_op, return failure(); auto input = splitv_op.getValue(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasRank()) return failure(); for (auto result : splitv_op.getResults()) { - auto result_type = result.getType().dyn_cast(); + auto result_type = mlir::dyn_cast(result.getType()); if (result_type == nullptr) return failure(); } @@ -371,20 +371,21 @@ LogicalResult PadSlice::matchAndRewrite(TFL::SliceOp slice_op, // We have to know the shape of the input, as well as the begin/size. // also, begin and size have to be constants. 
auto input = slice_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); if (input_type.getRank() >= 4) return failure(); auto begin = slice_op.getBegin(); - auto begin_type = begin.getType().dyn_cast_or_null(); + auto begin_type = mlir::dyn_cast_or_null(begin.getType()); if (!begin_type || !begin_type.hasStaticShape()) return failure(); auto size = slice_op.getSize(); - auto size_type = size.getType().dyn_cast_or_null(); + auto size_type = mlir::dyn_cast_or_null(size.getType()); if (!size_type || !size_type.hasStaticShape()) return failure(); - auto output_type = slice_op.getType().dyn_cast_or_null(); + auto output_type = + mlir::dyn_cast_or_null(slice_op.getType()); if (!output_type || !output_type.hasStaticShape()) return failure(); // Pad 0s in front of the begin. @@ -472,17 +473,17 @@ LogicalResult FullyConnectedToConv::matchAndRewrite( TFL::FullyConnectedOp fc_op, PatternRewriter& rewriter) const { // We have to know the shape of the input. auto input = fc_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); // We have to know the shape of the weight. auto weight = fc_op.getFilter(); - auto weight_type = weight.getType().dyn_cast_or_null(); + auto weight_type = mlir::dyn_cast_or_null(weight.getType()); if (!weight_type || !weight_type.hasStaticShape()) return failure(); // We have to know the shape of the output as well. auto output = fc_op.getResult(0); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null(output.getType()); if (!output_type || !output_type.hasStaticShape()) return failure(); // Insert a reshape after the input. @@ -532,13 +533,14 @@ LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op, PatternRewriter& rewriter) const { int rank = -1; for (auto input : concat_op.getValues()) { - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); rank = input_type.getRank(); } - auto output_type = concat_op.getType().dyn_cast_or_null(); + auto output_type = + mlir::dyn_cast_or_null(concat_op.getType()); if (!output_type || !output_type.hasStaticShape()) return failure(); if (rank >= 4) return failure(); @@ -547,7 +549,7 @@ LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op, // We will insert a reshape op after every input. SmallVector reshape_ops; for (auto input : concat_op.getValues()) { - auto input_type = input.getType().cast(); + auto input_type = mlir::cast(input.getType()); // Get the new shape. SmallVector new_shape; for (int i = 0; i < 4 - rank; ++i) { @@ -603,7 +605,7 @@ LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op, LogicalResult ReduceMeanToAvgPool::matchAndRewrite( TFL::MeanOp mean_op, PatternRewriter& rewriter) const { auto input = mean_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); // Only 4d is supported here. 
if (!input_type || input_type.getRank() != 4) return failure(); @@ -619,7 +621,7 @@ LogicalResult ReduceMeanToAvgPool::matchAndRewrite( } auto output = mean_op.getOutput(); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null(output.getType()); if (!output_type) return failure(); auto input_quantized_type = @@ -669,7 +671,7 @@ LogicalResult ReduceMeanToAvgPool::matchAndRewrite( LogicalResult InsertRequantForReduceMean::matchAndRewrite( TFL::MeanOp mean_op, PatternRewriter& rewriter) const { auto input = mean_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) return failure(); // Only need to do this for quantized input. @@ -678,7 +680,7 @@ LogicalResult InsertRequantForReduceMean::matchAndRewrite( if (!input_quantized_type) return failure(); auto output = mean_op.getOutput(); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null(output.getType()); if (!output_type) return failure(); auto output_quantized_type = quant::QuantizedType::getQuantizedElementType(output_type); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc index b6c544a8f69c9b..e4985f2b5700d5 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc @@ -107,11 +107,12 @@ bool IsConstOrQConstInt(Operation* op) { if (auto arith_const_op = dyn_cast_or_null(op)) { // arith ConstOp path. - auto type = arith_const_op.getType().cast().getElementType(); + auto type = + mlir::cast(arith_const_op.getType()).getElementType(); if (!type.isInteger(32) && !type.isInteger(64)) return false; } else if (auto const_op = dyn_cast_or_null(op)) { // ConstOp path. - auto type = const_op.getType().cast().getElementType(); + auto type = mlir::cast(const_op.getType()).getElementType(); if (!type.isInteger(32) && !type.isInteger(64)) return false; } else { // QConstOp path. 
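[Editor's note] The `IsConstOrQConstInt` hunk above is the type-side of the same migration: the constant's element type is pulled out of a `ShapedType` with `mlir::cast` before the i32/i64 width checks. Below is a hedged sketch of that check; the helper is hypothetical, and it stays defensive with `dyn_cast` because, unlike the pass, it is not restricted to already-matched tensor-typed constants.

// Illustrative sketch only (hypothetical helper, not from the patch).
#include "mlir/IR/BuiltinTypes.h"  // mlir::ShapedType
#include "mlir/IR/Types.h"         // mlir::Type
#include "mlir/Support/LLVM.h"     // re-exports llvm::dyn_cast into mlir::

// True if `type` (expected non-null) is a shaped type whose elements are
// i32 or i64, mirroring the width checks in IsConstOrQConstInt.
bool HasInt32Or64Elements(mlir::Type type) {
  auto shaped = mlir::dyn_cast<mlir::ShapedType>(type);
  if (!shaped) return false;
  mlir::Type element = shaped.getElementType();
  return element.isInteger(32) || element.isInteger(64);
}

The next file, raise_target_subgraphs.cc, uses the same `mlir::cast<StringAttr>` spelling to flatten a deeply chained attribute lookup into a single readable expression.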
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc index 4fd9f945764b3a..1ff585f6c71cb6 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc @@ -113,18 +113,11 @@ void AddAttrs(OpsAdded& ops_added, OpBuilder& builder, int func_count) { added_func_op->setAttr(kInterfaceNameAttr, interface_name); added_call_op->setAttr(kInterfaceNameAttr, interface_name); - StringAttr device = added_func_op->getRegion(0) - .getBlocks() - .front() - .front() - .getAttr(kDevice) - .cast(); - StringAttr inference_type = added_func_op->getRegion(0) - .getBlocks() - .front() - .front() - .getAttr(kInferenceType) - .cast(); + StringAttr device = mlir::cast( + added_func_op->getRegion(0).getBlocks().front().front().getAttr(kDevice)); + StringAttr inference_type = mlir::cast( + added_func_op->getRegion(0).getBlocks().front().front().getAttr( + kInferenceType)); added_call_op->setAttr(kDevice, device); added_call_op->setAttr(kInferenceType, inference_type); added_func_op->setAttr(kDevice, device); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc index a2f7441cc170b1..05cadcbb26b1e7 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc @@ -110,7 +110,7 @@ void ApplyTacFilter(ModuleOp module, const TacFilter& tac_filter, llvm::Regex op_regex(tac_filter.op_filter().op_name_pattern()); module.walk([&](Operation* op) { - auto named_loc = op->getLoc().dyn_cast(); + auto named_loc = mlir::dyn_cast(op->getLoc()); if (!named_loc) { return; } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 4655fa7c069b54..0b7bd8cc7177e6 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -38,6 +39,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -83,6 +85,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/low_bit_utils.h" #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" @@ -108,7 +111,6 @@ limitations under the License. 
#include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/graph_info.h" #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" -#include "tensorflow/lite/schema/mutable/schema_generated.h" #include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -161,6 +163,9 @@ ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; // used by the TOCO export. (It does not explain rationale for this choice.) constexpr size_t kInitialBufferSize = 10240; +// Flatbuffer fields to be padded to 16 bytes aligned. +constexpr size_t kFbAlignment = 16; + // Set `isSigned` to false if the `type` is an 8-bit unsigned integer type. // Since tflite doesn't support unsigned for other types, returns error if // `isSigned` is set to false for other types. @@ -185,11 +190,11 @@ static StatusOr GetTFLiteType(Type type, return tflite::TensorType_BFLOAT16; } else if (type.isF64()) { return tflite::TensorType_FLOAT64; - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_STRING; - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_UINT8; - } else if (auto complex_type = type.dyn_cast()) { + } else if (auto complex_type = mlir::dyn_cast(type)) { auto ftype = complex_type.getElementType(); if (ftype.isF32()) { return tflite::TensorType_COMPLEX64; @@ -198,7 +203,7 @@ static StatusOr GetTFLiteType(Type type, return tflite::TensorType_COMPLEX128; } return Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); - } else if (auto itype = type.dyn_cast()) { + } else if (auto itype = mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: return tflite::TensorType_BOOL; @@ -223,19 +228,20 @@ static StatusOr GetTFLiteType(Type type, : tflite::TensorType_INT64; } } else if (auto q_uniform_type = - type.dyn_cast()) { + mlir::dyn_cast(type)) { return GetTFLiteType(q_uniform_type.getStorageType(), q_uniform_type.isSigned()); } else if (auto q_peraxis_type = - type.dyn_cast()) { + mlir::dyn_cast( + type)) { return GetTFLiteType(q_peraxis_type.getStorageType(), q_peraxis_type.isSigned()); } else if (auto q_calibrated_type = - type.dyn_cast()) { + mlir::dyn_cast(type)) { return GetTFLiteType(q_calibrated_type.getExpressedType()); - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_RESOURCE; - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_VARIANT; } // TFLite export fills FLOAT32 for unknown data types. Returning an error @@ -253,13 +259,13 @@ static bool IsConst(Operation* op) { static bool IsTFResourceOp(Operation* op) { for (const auto& operand : op->getOperands()) { auto elementType = getElementTypeOrSelf(operand.getType()); - if (elementType.isa()) { + if (mlir::isa(elementType)) { return true; } } for (const auto& result : op->getResults()) { auto elementType = getElementTypeOrSelf(result.getType()); - if (elementType.isa()) { + if (mlir::isa(elementType)) { return true; } } @@ -305,7 +311,8 @@ static std::string GetOpDescriptionForDebug(Operation* inst) { os << (!first ? 
", " : ""); first = false; os << named_attr.getName().getValue() << " = "; - if (auto element_attr = named_attr.getValue().dyn_cast()) { + if (auto element_attr = + mlir::dyn_cast(named_attr.getValue())) { if (element_attr.getNumElements() <= kLargeElementsAttr) { element_attr.print(os); } else { @@ -350,9 +357,9 @@ static std::string GetOpsSummary( template static bool HasValidTFLiteType(Value value, T& error_handler) { // None type is allowed to represent unspecified operands. - if (value.getType().isa()) return true; + if (mlir::isa(value.getType())) return true; - auto type = value.getType().dyn_cast(); + auto type = mlir::dyn_cast(value.getType()); if (!type) { if (auto op = value.getDefiningOp()) { error_handler.emitError() @@ -411,7 +418,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) { for (auto arg : bb.getArguments()) { if (!HasValidTFLiteType(arg, fn)) { auto elementType = getElementTypeOrSelf(arg.getType()); - if (elementType.isa()) { + if (mlir::isa(elementType)) { return fn.emitError( "function argument uses variant type. Currently, the " "variant type is not natively supported in TFLite. Please " @@ -430,10 +437,10 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) { if (inst.hasTrait()) break; for (auto result : inst.getResults()) { - if (result.getType().isa()) continue; + if (mlir::isa(result.getType())) continue; if (!HasValidTFLiteType(result, inst)) { auto elementType = getElementTypeOrSelf(result.getType()); - if (elementType.isa()) { + if (mlir::isa(elementType)) { return inst.emitError( "operand result uses variant type. Currently, the " "variant type is not natively supported in TFLite. " @@ -716,7 +723,7 @@ class Translator { // Append constant and custom op buffers at the end of the flatbuffer and // calculate the offsets - void AppendBufferData(std::string& result); + void AppendBufferData(absl::Cord& result); // Update constant & custom op buffer offsets // Return false if fail to update offset @@ -767,6 +774,11 @@ class Translator { const std::vector& results, mlir::VhloToStablehloTypeConverter& vhlo_type_converter); + std::optional> BuildVhloCompositeV1Op( + mlir::vhlo::CompositeOpV1 composite_op, + const std::vector& operands, const std::vector& results, + std::string op_name); + std::optional> BuildVhloScatterV1Op( mlir::vhlo::ScatterOpV1 scatter_op, const std::vector& operands, const std::vector& results, @@ -816,7 +828,8 @@ class Translator { // Maps buffer data to corresponding buffer index // in the idx map, the value is a pair of offset and size absl::flat_hash_map> buffer_idx_map_; - absl::flat_hash_map> buffer_data_map_; + absl::flat_hash_map buffer_data_map_; + bool buffer_data_exported_ = false; // Maps custom options data to corresponding node // Key is set to be the list of input tensor indices and list of output tensor @@ -908,7 +921,7 @@ std::optional> Translator::BuildBuffer( if (auto cst = dyn_cast(inst)) { // arith::ConstantOp have ElementAttr at this point due to validation of the // TFLite module. 
- attr = cst.getValue().cast(); + attr = mlir::cast(cst.getValue()); } else if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); } else if (auto cst = dyn_cast(inst)) { @@ -919,10 +932,10 @@ std::optional> Translator::BuildBuffer( attr = cst.getValue(); } else if (auto cst = dyn_cast(inst)) { mlir::VhloToStablehloTypeConverter vhlo_type_converter; - auto tensor_v1_attr = cst.getValue().cast(); + auto tensor_v1_attr = mlir::cast(cst.getValue()); attr = mlir::DenseIntOrFPElementsAttr::getFromRawBuffer( - vhlo_type_converter.convertType(tensor_v1_attr.getType()) - .cast(), + mlir::cast( + vhlo_type_converter.convertType(tensor_v1_attr.getType())), tensor_v1_attr.getData()); } else if (auto cst = dyn_cast(inst)) { attr = cst.getCompressedData(); @@ -945,7 +958,7 @@ std::optional> Translator::BuildBuffer( // trouble calling ConvertToTensor(). For now, extract the tensor data from // ElementsAttr directly in this and read type from tflite::TensorType instead // of tensorflow::DataType. - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); tflite::TensorType tflite_element_type = GetTFLiteType(type.getElementType()).value(); if (tflite_element_type == tflite::TensorType_INT4) { @@ -955,7 +968,8 @@ std::optional> Translator::BuildBuffer( } auto packed_buffer = tflite::PackInt4ValuesDensely(data); if (use_buffer_offset_) { - buffer_data_map_[index] = packed_buffer; + buffer_data_map_[index] = + std::string(packed_buffer.begin(), packed_buffer.end()); return tflite::CreateBuffer(builder_, 0, 1, 1); } else { if (IsModelBiggerThan2GB(packed_buffer.size())) { @@ -991,7 +1005,8 @@ std::optional> Translator::BuildBuffer( if (use_buffer_offset_) { std::vector buffer_data(tensor_buffer, tensor_buffer + bytes); free(tensor_buffer); - buffer_data_map_[index] = buffer_data; + buffer_data_map_[index] = + std::string(buffer_data.begin(), buffer_data.end()); return tflite::CreateBuffer(builder_, 0, 1, 1); } else { if (IsModelBiggerThan2GB(bytes)) { @@ -1007,9 +1022,7 @@ std::optional> Translator::BuildBuffer( absl::string_view tensor_data = tensor.tensor_data(); if (use_buffer_offset_) { - std::vector buffer_data(tensor_data.data(), - tensor_data.data() + tensor_data.size()); - buffer_data_map_[index] = buffer_data; + buffer_data_map_[index] = std::string(tensor_data); return tflite::CreateBuffer(builder_, 0, 1, 1); } else { if (IsModelBiggerThan2GB(tensor_data.size())) { @@ -1041,7 +1054,7 @@ int32_t Translator::UnnamedRegionToSubgraph( std::optional>> Translator::BuildTFVariantType(mlir::Type element_type) { std::vector> variant_params; - auto variant_type = element_type.dyn_cast(); + auto variant_type = mlir::dyn_cast(element_type); if (!variant_type) { return variant_params; } @@ -1070,7 +1083,7 @@ Translator::BuildTFVariantType(mlir::Type element_type) { std::optional> Translator::BuildTensorFromType( mlir::Type type, const std::string& name) { - auto tensor_type = type.cast(); + auto tensor_type = mlir::cast(type); llvm::ArrayRef shape_ref; std::vector shape; @@ -1093,15 +1106,15 @@ std::optional> Translator::BuildTensorFromType( return std::nullopt; } BufferOffset q_params = 0; - if (auto qtype = element_type.dyn_cast()) { + if (auto qtype = + mlir::dyn_cast(element_type)) { std::vector scales = {static_cast(qtype.getScale())}; std::vector zero_points = {qtype.getZeroPoint()}; q_params = tflite::CreateQuantizationParameters( builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), builder_.CreateVector(zero_points)); - } else if (auto qtype = - element_type - 
.dyn_cast()) { + } else if (auto qtype = mlir::dyn_cast( + element_type)) { std::vector mins = {static_cast(qtype.getMin())}; std::vector maxs = {static_cast(qtype.getMax())}; q_params = tflite::CreateQuantizationParameters( @@ -1120,7 +1133,7 @@ std::optional> Translator::BuildTensor( Value value, const std::string& name, unsigned buffer_idx, const std::optional>& quant_parameters) { - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); // TFLite requires tensor shape only for the inputs and constants. // However, we output all known shapes for better round-tripping @@ -1150,9 +1163,9 @@ std::optional> Translator::BuildTensor( // Const op can have a result of dynamic shaped type (e.g. due to constant // folding), but we can still derive the shape of a constant tensor for // its attribute type. - auto tensor_attr = inst->getAttr("value").cast(); + auto tensor_attr = mlir::cast(inst->getAttr("value")); llvm::ArrayRef shape_ref = - tensor_attr.getType().cast().getShape(); + mlir::cast(tensor_attr.getType()).getShape(); if (mlir::failed(check_shape(shape_ref))) return std::nullopt; shape = std::vector(shape_ref.begin(), shape_ref.end()); @@ -1191,7 +1204,8 @@ std::optional> Translator::BuildTensor( } BufferOffset q_params; - if (auto qtype = element_type.dyn_cast()) { + if (auto qtype = + mlir::dyn_cast(element_type)) { std::vector scales = {static_cast(qtype.getScale())}; std::vector zero_points = {qtype.getZeroPoint()}; q_params = tflite::CreateQuantizationParameters( @@ -1200,8 +1214,8 @@ std::optional> Translator::BuildTensor( builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), builder_.CreateVector(zero_points)); } else if (auto qtype = - element_type - .dyn_cast()) { + mlir::dyn_cast( + element_type)) { std::vector scales(qtype.getScales().begin(), qtype.getScales().end()); std::vector zero_points(qtype.getZeroPoints().begin(), @@ -1339,7 +1353,9 @@ BufferOffset Translator::BuildCustomOperator( Operation* inst, mlir::TFL::CustomOp op, const std::vector& operands, const std::vector& results) { const std::string attrs = - op.getCustomOption().cast().getValue().str(); + mlir::cast(op.getCustomOption()) + .getValue() + .str(); std::vector custom_option_vector(attrs.size(), 0); memcpy(custom_option_vector.data(), attrs.data(), attrs.size()); auto opcode_index = @@ -1492,6 +1508,43 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name, return it.first->second; } +void CreateFlexbufferVector( + const std::unique_ptr& flex_builder, + std::string& name, const mlir::Attribute& attr) { + auto start = flex_builder->StartVector(name.c_str()); + auto array = attr.cast().getValue(); + + for (int i = 0; i < array.size(); i++) { + if (llvm::isa(array[i])) { + flex_builder->Bool(name.c_str(), + array[i].cast().getValue()); + } else if (llvm::isa(attr)) { + flex_builder->String(name.c_str(), + array[i].cast().getValue().str()); + } else if (llvm::isa(array[i])) { + flex_builder->Bool(name.c_str(), + array[i].cast().getValue()); + } else if (llvm::isa(array[i])) { + flex_builder->String( + name.c_str(), + array[i].cast().getValue().str()); + } else if (llvm::isa(array[i])) { + flex_builder->Int( + name.c_str(), + array[i].cast().getValue().getSExtValue()); + } else if (llvm::isa(array[i])) { + flex_builder->Float( + name.c_str(), + array[i].cast().getValue().convertToFloat()); + + } else if (llvm::isa(array[i])) { + CreateFlexbufferVector(flex_builder, name, array[i]); + } + } + + flex_builder->EndVector(start, /*typed=*/false, /*fixed=*/false); +} + 
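// ---------------------------------------------------------------------------
// Aside: CreateFlexbufferVector (above) recursively serializes an MLIR
// ArrayAttr into an untyped FlexBuffer vector, and BuildVhloCompositeV1Op
// nests such vectors inside a FlexBuffer map of composite attributes. Below
// is a minimal, self-contained sketch of the same write/read pattern using
// the flexbuffers API from @flatbuffers; the keys "version" and "dims" are
// illustrative only, not taken from this patch.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <iostream>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers

int main() {
  flexbuffers::Builder fbb;
  // Writer side: a map of scalars plus an untyped nested vector -- the same
  // shape of data the exporter emits for composite attributes.
  size_t map_start = fbb.StartMap();
  fbb.Int("version", 2);
  size_t vec_start = fbb.StartVector("dims");
  fbb.Int(128);
  fbb.Int(256);
  fbb.EndVector(vec_start, /*typed=*/false, /*fixed=*/false);
  fbb.EndMap(map_start);
  fbb.Finish();
  const std::vector<uint8_t>& buf = fbb.GetBuffer();

  // Reader side: the importer walks the same structure via GetRoot().AsMap().
  auto root = flexbuffers::GetRoot(buf.data(), buf.size()).AsMap();
  std::cout << root["version"].AsInt64() << "\n";  // prints 2
  auto dims = root["dims"].AsVector();
  for (size_t i = 0; i < dims.size(); ++i) {
    std::cout << dims[i].AsInt64() << "\n";  // prints 128 then 256
  }
  return 0;
}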
std::optional> Translator::BuildStablehloOperatorwithoutOptions( Operation* inst, const std::vector& operands, @@ -1511,7 +1564,7 @@ Translator::BuildStablehloPrecisionConfig(::mlir::ArrayAttr precisionConfig) { for (auto it = precisionConfig.begin(); it != precisionConfig.end(); it++) { precision_config_vec.push_back(static_cast( - (it->cast()).getValue())); + (mlir::cast(*it)).getValue())); } return builder_.CreateVector(precision_config_vec); } @@ -1523,7 +1576,7 @@ Translator::BuildVhloPrecisionConfigV1( auto values = precisionConfig.getValue(); for (auto it = values.begin(); it != values.end(); it++) { precision_config_vec.push_back(static_cast( - (it->cast()).getValue())); + (mlir::cast(*it)).getValue())); } return builder_.CreateVector(precision_config_vec); } @@ -1568,6 +1621,78 @@ Translator::BuildStablehloGatherOp(mlir::stablehlo::GatherOp gather_op, tflite::BuiltinOptions2_StablehloGatherOptions, gather_option.Union()); } +std::optional> +Translator::BuildVhloCompositeV1Op(mlir::vhlo::CompositeOpV1 composite_op, + const std::vector& operands, + const std::vector& results, + std::string op_name) { + uint32_t opcode_index = + GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_COMPOSITE); + + int32_t api_version = composite_op.getVersion() + .cast() + .getValue() + .getSExtValue(); + + auto name = builder_.CreateString( + composite_op.getName().cast().getValue().str()); + + auto composite_attributes = composite_op.getCompositeAttributes() + .cast(); + auto flex_builder = std::make_unique(); + size_t map_start = flex_builder->StartMap(); + + for (auto namedAttr : composite_attributes.getValue()) { + auto name = + namedAttr.first.cast().getValue().str(); + auto attr = namedAttr.second; + + if (llvm::isa(attr)) + flex_builder->Bool(name.c_str(), attr.cast().getValue()); + else if (llvm::isa(attr)) + flex_builder->String(name.c_str(), + attr.cast().getValue().str()); + else if (llvm::isa(attr)) + flex_builder->Bool(name.c_str(), + attr.cast().getValue()); + else if (llvm::isa(attr)) + flex_builder->String( + name.c_str(), attr.cast().getValue().str()); + else if (llvm::isa(attr)) + flex_builder->Int( + name.c_str(), + attr.cast().getValue().getSExtValue()); + else if (llvm::isa(attr)) + flex_builder->Float( + name.c_str(), + attr.cast().getValue().convertToFloat()); + } + + flex_builder->EndMap(map_start); + flex_builder->Finish(); + + int32_t decomposition_subgraph_index = + subgraph_index_map_[composite_op.getDecomposition() + .cast() + .getValue() + .str()]; + + auto composite_option = tflite::CreateStableHLOCompositeOptions( + builder_, name, decomposition_subgraph_index, + builder_.CreateVector(flex_builder->GetBuffer()), + tflite::CustomOptionsFormat_FLEXBUFFERS, api_version); + + return tflite::CreateOperator( + builder_, opcode_index, builder_.CreateVector(operands), + builder_.CreateVector(results), tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, /*custom_options=*/0, + tflite::CustomOptionsFormat_FLEXBUFFERS, + /*mutating_variable_inputs=*/0, /*intermediates=*/0, + /*large_custom_options_offset=*/0, /*large_custom_options_size=*/0, + tflite::BuiltinOptions2_StableHLOCompositeOptions, + composite_option.Union()); +} + std::optional> Translator::BuildStablehloScatterOp(mlir::stablehlo::ScatterOp scatter_op, const std::vector& operands, @@ -1732,27 +1857,25 @@ std::optional> Translator::BuildVhloGatherV1Op( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_GATHER); auto offset_dims = builder_.CreateVector(mlir::GetVector( - gather_op.getOffsetDims().cast(), 
+ mlir::cast(gather_op.getOffsetDims()), vhlo_type_converter)); auto collapsed_slice_dims = builder_.CreateVector(mlir::GetVector( - gather_op.getCollapsedSliceDims().cast(), + mlir::cast(gather_op.getCollapsedSliceDims()), vhlo_type_converter)); auto start_index_map = builder_.CreateVector(mlir::GetVector( - gather_op.getStartIndexMap().cast(), + mlir::cast(gather_op.getStartIndexMap()), vhlo_type_converter)); auto slice_sizes = builder_.CreateVector(mlir::GetVector( - gather_op.getSliceSizes().cast(), + mlir::cast(gather_op.getSliceSizes()), vhlo_type_converter)); auto gather_option = tflite::CreateStablehloGatherOptions( builder_, offset_dims, collapsed_slice_dims, start_index_map, - gather_op.getIndexVectorDim() - .cast() + mlir::cast(gather_op.getIndexVectorDim()) .getValue() .getSExtValue(), slice_sizes, - gather_op.getIndicesAreSorted() - .cast() + mlir::cast(gather_op.getIndicesAreSorted()) .getValue()); return tflite::CreateOperator( @@ -1779,26 +1902,26 @@ std::optional> Translator::BuildVhloScatterV1Op( UnnamedRegionToSubgraph(&body, tflite::BuiltinOperator_STABLEHLO_SCATTER); if (subgraph_index < 0) return std::nullopt; - int64_t index_vector_dim = scatter_op.getIndexVectorDim() - .cast() - .getValue() - .getSExtValue(); - bool unique_indices = scatter_op.getUniqueIndices() - .cast() - .getValue(); - bool indices_are_sorted = scatter_op.getIndicesAreSorted() - .cast() - .getValue(); + int64_t index_vector_dim = + mlir::cast(scatter_op.getIndexVectorDim()) + .getValue() + .getSExtValue(); + bool unique_indices = + mlir::cast(scatter_op.getUniqueIndices()) + .getValue(); + bool indices_are_sorted = + mlir::cast(scatter_op.getIndicesAreSorted()) + .getValue(); auto update_window_dims = builder_.CreateVector(mlir::GetVector( - scatter_op.getUpdateWindowDims().cast(), + mlir::cast(scatter_op.getUpdateWindowDims()), vhlo_type_converter)); auto inserted_window_dims = builder_.CreateVector(mlir::GetVector( - scatter_op.getInsertedWindowDims().cast(), + mlir::cast(scatter_op.getInsertedWindowDims()), vhlo_type_converter)); auto scatter_dims_to_operand_dims = builder_.CreateVector( - mlir::GetVector(scatter_op.getScatterDimsToOperandDims() - .cast(), + mlir::GetVector(mlir::cast( + scatter_op.getScatterDimsToOperandDims()), vhlo_type_converter)); auto options = tflite::CreateStablehloScatterOptions( @@ -1826,20 +1949,22 @@ Translator::BuildVhloReduceWindowV1Op( uint32_t opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW); - auto window_dimensions = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getWindowDimensions().cast(), - vhlo_type_converter)); + auto window_dimensions = builder_.CreateVector( + mlir::GetVector(mlir::cast( + reduce_window_op.getWindowDimensions()), + vhlo_type_converter)); auto window_strides = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getWindowStrides().cast(), + mlir::cast(reduce_window_op.getWindowStrides()), vhlo_type_converter)); auto base_dilations = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getBaseDilations().cast(), - vhlo_type_converter)); - auto window_dilations = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getWindowDilations().cast(), + mlir::cast(reduce_window_op.getBaseDilations()), vhlo_type_converter)); + auto window_dilations = builder_.CreateVector( + mlir::GetVector(mlir::cast( + reduce_window_op.getWindowDilations()), + vhlo_type_converter)); auto padding = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getPadding().cast(), + 
mlir::cast(reduce_window_op.getPadding()), vhlo_type_converter)); auto& body = reduce_window_op.getBody(); int32_t subgraph_index = UnnamedRegionToSubgraph( @@ -1870,8 +1995,7 @@ Translator::BuildVhloRngBitGeneratorV1Op( uint32_t opcode_index = GetOpcodeIndex( op_name, tflite::BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR); tflite::RngAlgorithm algorithm = tflite::RngAlgorithm_DEFAULT; - switch (rng_op.getRngAlgorithm() - .cast() + switch (mlir::cast(rng_op.getRngAlgorithm()) .getValue()) { case mlir::vhlo::RngAlgorithmV1::THREE_FRY: algorithm = tflite::RngAlgorithm_THREEFRY; @@ -1904,13 +2028,13 @@ std::optional> Translator::BuildVhloPadV1Op( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_PAD); auto edge_padding_low = builder_.CreateVector(mlir::GetVector( - pad_op.getEdgePaddingLow().cast(), + mlir::cast(pad_op.getEdgePaddingLow()), vhlo_type_converter)); auto edge_padding_high = builder_.CreateVector(mlir::GetVector( - pad_op.getEdgePaddingHigh().cast(), + mlir::cast(pad_op.getEdgePaddingHigh()), vhlo_type_converter)); auto interior_padding = builder_.CreateVector(mlir::GetVector( - pad_op.getInteriorPadding().cast(), + mlir::cast(pad_op.getInteriorPadding()), vhlo_type_converter)); auto pad_option = tflite::CreateStablehloPadOptions( @@ -2031,6 +2155,10 @@ std::optional> Translator::BuildOperator( if (auto vhlo_op = llvm::dyn_cast(inst)) { return BuildVhloPadV1Op(vhlo_op, operands, results, vhlo_type_converter); } + if (auto vhlo_op = llvm::dyn_cast(inst)) { + return BuildVhloCompositeV1Op(vhlo_op, operands, results, + inst->getName().getStringRef().str()); + } // for ops don't have kernels, only serialize when conversion is set to // true if (convert_stablehlo_) { @@ -2139,10 +2267,10 @@ std::optional> Translator::BuildOperator( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_IOTA); auto iota_option = tflite::CreateStablehloIotaOptions( - builder_, vhlo_op.getIotaDimension() - .cast() - .getValue() - .getSExtValue()); + builder_, + mlir::cast(vhlo_op.getIotaDimension()) + .getValue() + .getSExtValue()); return tflite::CreateOperator( builder_, opcode_index, builder_.CreateVector(operands), @@ -2156,7 +2284,7 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_DYNAMIC_SLICE); auto slice_sizes = builder_.CreateVector(mlir::GetVector( - vhlo_op.getSliceSizes().cast(), + mlir::cast(vhlo_op.getSliceSizes()), vhlo_type_converter)); auto dynamic_slice_option = @@ -2179,13 +2307,13 @@ std::optional> Translator::BuildOperator( tflite::StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_NOTYPE; if (compare_type_attr) compare_type = static_cast( - compare_type_attr.cast() + mlir::cast(compare_type_attr) .getValue()); auto compare_option = tflite::CreateStablehloCompareOptions( builder_, static_cast( - vhlo_op.getComparisonDirection() - .cast() + mlir::cast( + vhlo_op.getComparisonDirection()) .getValue()), compare_type); @@ -2202,10 +2330,10 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_CONCATENATE); auto concat_option = tflite::CreateStablehloConcatenateOptions( - builder_, vhlo_op.getDimension() - .cast() - .getValue() - .getSExtValue()); + builder_, + mlir::cast(vhlo_op.getDimension()) + .getValue() + .getSExtValue()); return tflite::CreateOperator( builder_, opcode_index, builder_.CreateVector(operands), @@ -2220,13 +2348,13 @@ std::optional> Translator::BuildOperator( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_SLICE); auto start_indices = builder_.CreateVector((mlir::GetVector( - 
vhlo_op.getStartIndicesAttr().cast(), + mlir::cast(vhlo_op.getStartIndicesAttr()), vhlo_type_converter))); auto limit_indices = builder_.CreateVector(mlir::GetVector( - vhlo_op.getLimitIndicesAttr().cast(), + mlir::cast(vhlo_op.getLimitIndicesAttr()), vhlo_type_converter)); auto strides = builder_.CreateVector(mlir::GetVector( - vhlo_op.getStridesAttr().cast(), + mlir::cast(vhlo_op.getStridesAttr()), vhlo_type_converter)); auto slice_option = tflite::CreateStablehloSliceOptions( @@ -2245,63 +2373,64 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_CONVOLUTION); auto window_strides = builder_.CreateVector(mlir::GetVector( - vhlo_op.getWindowStrides().cast(), + mlir::cast(vhlo_op.getWindowStrides()), vhlo_type_converter)); auto padding = builder_.CreateVector(mlir::GetVector( - vhlo_op.getPadding().cast(), + mlir::cast(vhlo_op.getPadding()), vhlo_type_converter)); auto lhs_dialation = builder_.CreateVector(mlir::GetVector( - vhlo_op.getLhsDilation().cast(), + mlir::cast(vhlo_op.getLhsDilation()), vhlo_type_converter)); auto rhs_dialation = builder_.CreateVector(mlir::GetVector( - vhlo_op.getRhsDilation().cast(), + mlir::cast(vhlo_op.getRhsDilation()), vhlo_type_converter)); auto window_reversal = builder_.CreateVector(mlir::GetVector( - vhlo_op.getWindowReversal().cast(), + mlir::cast(vhlo_op.getWindowReversal()), vhlo_type_converter)); - auto input_batch_dimension = vhlo_op.getInputBatchDimension() - .cast() + auto input_batch_dimension = mlir::cast( + vhlo_op.getInputBatchDimension()) .getValue() .getSExtValue(); - auto input_feature_dimension = vhlo_op.getInputFeatureDimension() - .cast() + auto input_feature_dimension = mlir::cast( + vhlo_op.getInputFeatureDimension()) .getValue() .getSExtValue(); auto kernel_input_feature_dimension = - vhlo_op.getKernelInputFeatureDimension() - .cast() + mlir::cast( + vhlo_op.getKernelInputFeatureDimension()) .getValue() .getSExtValue(); auto kernel_output_feature_dimension = - vhlo_op.getKernelOutputFeatureDimension() - .cast() + mlir::cast( + vhlo_op.getKernelOutputFeatureDimension()) .getValue() .getSExtValue(); - auto output_batch_dimension = vhlo_op.getOutputBatchDimension() - .cast() + auto output_batch_dimension = mlir::cast( + vhlo_op.getOutputBatchDimension()) .getValue() .getSExtValue(); - auto output_feature_dimension = vhlo_op.getOutputFeatureDimension() - .cast() + auto output_feature_dimension = mlir::cast( + vhlo_op.getOutputFeatureDimension()) .getValue() .getSExtValue(); auto kernel_spatial_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getKernelSpatialDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getKernelSpatialDimensions()), vhlo_type_converter)); auto output_spatial_dimension = builder_.CreateVector( - mlir::GetVector(vhlo_op.getOutputSpatialDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getOutputSpatialDimensions()), vhlo_type_converter)); auto input_spatial_dimension = builder_.CreateVector( - mlir::GetVector(vhlo_op.getInputSpatialDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getInputSpatialDimensions()), vhlo_type_converter)); BufferOffset> precision_config = 0; if (vhlo_op.getPrecisionConfig()) { precision_config = BuildVhloPrecisionConfigV1( - vhlo_op.getPrecisionConfig().dyn_cast()); + mlir::dyn_cast( + vhlo_op.getPrecisionConfig())); } auto convolution_option = tflite::CreateStablehloConvolutionOptions( @@ -2311,12 +2440,11 @@ std::optional> Translator::BuildOperator( kernel_output_feature_dimension, 
kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimension, - vhlo_op.getFeatureGroupCount() - .cast() + mlir::cast( + vhlo_op.getFeatureGroupCount()) .getValue() .getSExtValue(), - vhlo_op.getBatchGroupCount() - .cast() + mlir::cast(vhlo_op.getBatchGroupCount()) .getValue() .getSExtValue(), precision_config); @@ -2334,8 +2462,8 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM); auto broadcast_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getBroadcastDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getBroadcastDimensions()), vhlo_type_converter)); auto broadcast_option = tflite::CreateStablehloBroadcastInDimOptions( @@ -2354,8 +2482,8 @@ std::optional> Translator::BuildOperator( uint32_t opcode_index = GetOpcodeIndex( op_name, tflite::BuiltinOperator_STABLEHLO_CUSTOM_CALL); auto op_api_version = - vhlo_op.getApiVersion() - .cast() + mlir::cast( + vhlo_op.getApiVersion()) .getValue(); int32_t api_version = 0; if (op_api_version == @@ -2371,16 +2499,14 @@ std::optional> Translator::BuildOperator( API_VERSION_STATUS_RETURNING_UNIFIED) api_version = 3; - auto call_target_name = - builder_.CreateString(vhlo_op.getCallTargetName() - .cast() - .getValue() - .str()); - auto backend_config = - builder_.CreateString(vhlo_op.getBackendConfig() - .cast() - .getValue() - .str()); + auto call_target_name = builder_.CreateString( + mlir::cast(vhlo_op.getCallTargetName()) + .getValue() + .str()); + auto backend_config = builder_.CreateString( + mlir::cast(vhlo_op.getBackendConfig()) + .getValue() + .str()); // building the computation info auto flex_builder = std::make_unique(); size_t map_start = flex_builder->StartMap(); @@ -2393,25 +2519,25 @@ std::optional> Translator::BuildOperator( if (name == "call_target_name" || name == "backend_config") continue; if (llvm::isa(attr)) flex_builder->Bool(name.c_str(), - attr.cast().getValue()); + mlir::cast(attr).getValue()); if (llvm::isa(attr)) flex_builder->String( - name.c_str(), attr.cast().getValue().str()); + name.c_str(), + mlir::cast(attr).getValue().str()); if (llvm::isa(attr)) flex_builder->Bool( name.c_str(), - attr.cast().getValue()); + mlir::cast(attr).getValue()); if (llvm::isa(attr)) flex_builder->String( name.c_str(), - attr.cast().getValue().str()); + mlir::cast(attr).getValue().str()); } flex_builder->EndMap(map_start); flex_builder->Finish(); auto custom_call_option = tflite::CreateStablehloCustomCallOptions( builder_, call_target_name, - vhlo_op.getHasSideEffect() - .cast<::mlir::vhlo::BooleanV1Attr>() + mlir::cast<::mlir::vhlo::BooleanV1Attr>(vhlo_op.getHasSideEffect()) .getValue(), backend_config, api_version, 0, builder_.CreateVector(flex_builder->GetBuffer())); @@ -2429,7 +2555,7 @@ std::optional> Translator::BuildOperator( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_REDUCE); auto dimension = builder_.CreateVector(mlir::GetVector( - vhlo_op.getDimensions().cast(), + mlir::cast(vhlo_op.getDimensions()), vhlo_type_converter)); auto& body = vhlo_op.getBody(); int32_t subgraph_index = UnnamedRegionToSubgraph( @@ -2452,26 +2578,27 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_DOT_GENERAL); auto lhs_batching_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getLhsBatchingDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getLhsBatchingDimensions()), vhlo_type_converter)); auto rhs_batching_dimensions = builder_.CreateVector( - 
mlir::GetVector(vhlo_op.getRhsBatchingDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getRhsBatchingDimensions()), vhlo_type_converter)); auto lhs_contracting_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getLhsContractingDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getLhsContractingDimensions()), vhlo_type_converter)); auto rhs_contracting_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getRhsContractingDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getRhsContractingDimensions()), vhlo_type_converter)); BufferOffset> precision_config = 0; if (vhlo_op.getPrecisionConfig()) { - precision_config = BuildVhloPrecisionConfigV1( - vhlo_op.getPrecisionConfig().cast()); + precision_config = + BuildVhloPrecisionConfigV1(mlir::cast( + vhlo_op.getPrecisionConfig())); } auto dot_geneoral_option = tflite::CreateStablehloDotGeneralOptions( @@ -2497,11 +2624,11 @@ std::optional> Translator::BuildOperator( auto sort_option = tflite::CreateStablehloSortOptions( builder_, - vhlo_op.getDimension() - .cast() + mlir::cast(vhlo_op.getDimension()) .getValue() .getSExtValue(), - vhlo_op.getIsStable().cast().getValue(), + mlir::cast(vhlo_op.getIsStable()) + .getValue(), comparator_subgraph_index); return tflite::CreateOperator( @@ -2543,7 +2670,7 @@ std::optional> Translator::BuildOperator( auto transpose_option = tflite::CreateStablehloTransposeOptions( builder_, builder_.CreateVector(mlir::GetVector( - vhlo_op.getPermutation().cast(), + mlir::cast(vhlo_op.getPermutation()), vhlo_type_converter))); return tflite::CreateOperator( @@ -2669,7 +2796,8 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { llvm::SmallVector input_names; llvm::SmallVector output_names; - if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { + if (auto str = + mlir::dyn_cast_or_null(dict_attr.get("inputs"))) { str.getValue().split(input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); if (input_names.size() != fn.getNumArguments()) { @@ -2683,7 +2811,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { } if (auto str = - dict_attr.get("outputs").dyn_cast_or_null()) { + mlir::dyn_cast_or_null(dict_attr.get("outputs"))) { str.getValue().split(output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); auto term = fn.back().getTerminator(); @@ -2708,13 +2836,14 @@ bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { BufferOffset Translator::GetQuantizationForQuantStatsOpOutput( mlir::quantfork::StatisticsOp stats_op) { - auto layer_stats = stats_op.getLayerStats().cast(); + auto layer_stats = + mlir::cast(stats_op.getLayerStats()); std::optional axis_stats = stats_op.getAxisStats(); std::optional axis = stats_op.getAxis(); std::vector mins, maxs; mlir::DenseFPElementsAttr min_max_attr = axis_stats.has_value() - ? axis_stats.value().cast() + ? mlir::cast(axis_stats.value()) : layer_stats; for (const auto& index_and_value : @@ -2749,7 +2878,7 @@ std::optional> Translator::BuildSubGraph( auto build_tensor_and_buffer = [&](Value value, const int subgraph_index, const std::string& tensor_name) { // NoneType represents optional and may be skipped here. 
- if (value.getType().isa()) { + if (mlir::isa(value.getType())) { return true; } @@ -2833,7 +2962,8 @@ std::optional> Translator::BuildSubGraph( "effective_hidden_scale_intermediate"}; for (const std::string& intermediate : intermediate_names) { auto intermediate_attr = inst.getAttr(intermediate); - if (auto attr = intermediate_attr.dyn_cast_or_null()) { + if (auto attr = + mlir::dyn_cast_or_null(intermediate_attr)) { Type qtype = attr.getValue(); auto tensor_or = BuildTensorFromType( qtype, name_mapper_.GetUniqueName(intermediate).str()); @@ -2879,7 +3009,7 @@ std::optional> Translator::BuildSubGraph( std::vector operands; operands.reserve(real_inst->getNumOperands()); for (auto operand : real_inst->getOperands()) { - if (operand.getType().isa()) + if (mlir::isa(operand.getType())) operands.push_back(kTfLiteOptionalTensor); else if (auto stats_op = llvm::dyn_cast_or_null( @@ -2960,7 +3090,7 @@ Translator::CreateMetadataVector() { for (const auto& named_attr : dict_attr) { StringRef name = named_attr.getName(); mlir::Attribute attr = named_attr.getValue(); - if (auto content = attr.dyn_cast()) { + if (auto content = mlir::dyn_cast(attr)) { metadata.push_back(BuildMetadata(name, content.getValue())); } else { module_.emitError( @@ -3008,7 +3138,7 @@ Translator::CreateMetadataVector() { llvm::SmallVector GetStringsFromAttrWithSeparator( mlir::DictionaryAttr attr, const std::string& attr_key) { llvm::SmallVector result; - if (auto str = attr.get(attr_key).dyn_cast_or_null()) { + if (auto str = mlir::dyn_cast_or_null(attr.get(attr_key))) { str.getValue().split(result, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); } @@ -3027,9 +3157,11 @@ std::vector GetStringsFromDictionaryAttr( auto attrs = arg_attr.getValue(); for (const auto attr : attrs) { if (attr.getName() == attr_name) { - auto array_attr = attr.getValue().dyn_cast_or_null(); + auto array_attr = + mlir::dyn_cast_or_null(attr.getValue()); if (!array_attr || array_attr.empty()) continue; - auto string_attr = array_attr[0].dyn_cast_or_null(); + auto string_attr = + mlir::dyn_cast_or_null(array_attr[0]); if (!string_attr) continue; result.push_back(string_attr.getValue().str()); } @@ -3112,7 +3244,7 @@ std::vector BuildSignaturedef( auto unique_name = std::string(name_mapper.GetUniqueName(operand.get())); result[0].outputs[sig_def_outputs[i]] = unique_name; } - if (auto name_attr = exported_name[0].dyn_cast_or_null()) + if (auto name_attr = mlir::dyn_cast_or_null(exported_name[0])) result[0].signature_key = name_attr.getValue().str(); result[0].subgraph_index = subgraph_index; return result; @@ -3197,17 +3329,19 @@ std::optional Translator::Translate( op_or_arg_name_mapper = &default_op_or_arg_name_mapper; if (!UpdateEntryFunction(module)) return std::nullopt; if (!IsValidTFLiteMlirModule(module)) return std::nullopt; - Translator translator(module, toco_flags, tags, op_or_arg_name_mapper, - metadata, custom_option_alignment); - translator.convert_stablehlo_ = serialize_stablehlo_ops; - auto ret = translator.TranslateInternal(); - if (translator.require_use_buffer_offset_) { + auto translator = std::unique_ptr( + new Translator(module, toco_flags, tags, op_or_arg_name_mapper, metadata, + custom_option_alignment)); + translator->convert_stablehlo_ = serialize_stablehlo_ops; + auto ret = translator->TranslateInternal(); + if (translator->require_use_buffer_offset_) { + ret = std::nullopt; auto new_toco_flags = toco_flags; new_toco_flags.set_use_buffer_offset(true); - Translator new_translator(module, new_toco_flags, tags, - 
op_or_arg_name_mapper, metadata,
-                              custom_option_alignment);
-    return new_translator.TranslateInternal();
+    translator = std::unique_ptr<Translator>(
+        new Translator(module, new_toco_flags, tags, op_or_arg_name_mapper,
+                       metadata, custom_option_alignment));
+    return translator->TranslateInternal();
   }
   return ret;
 }
@@ -3453,63 +3587,91 @@ std::optional<std::string> Translator::TranslateInternal() {
     }
   }
 
-  auto result =
-      std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
-                  builder_.GetSize());
+  absl::Cord result;
+  auto fbs = absl::string_view(
+      reinterpret_cast<const char*>(builder_.GetBufferPointer()),
+      builder_.GetSize());
+  result.Append(fbs);
 
   // Return serialized string for the built FlatBuffer.
   if (use_buffer_offset_) {
+    // Pad to be 16 bytes aligned
+    {
+      std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
+      result.Append(std::move(pad));
+    }
     AppendBufferData(result);
-    auto mutable_model = tflite::GetMutableModel(result.data());
+    std::string result_str = std::string(std::move(result));
+    auto mutable_model = tflite::GetMutableModel(result_str.data());
     bool ret = UpdateBufferOffsets(mutable_model);
     if (!ret) {
       return std::nullopt;
    }
-    return result;
+    return result_str;
   }
-  return result;
+  return std::string(result);
 }
 
-void Translator::AppendBufferData(std::string& result) {
+void Translator::AppendBufferData(absl::Cord& result) {
   std::unordered_map<uint64_t, std::pair<int64_t, int64_t>> hashcode_to_pos;
-  // Pad to be 16 bytes aligned
-  while (result.size() % 16 != 0) result += '\0';
-  for (auto& it : buffer_data_map_) {
-    auto buffer = std::string(it.second.begin(), it.second.end());
-    int64_t index = it.first;
+  // Buffer data should be exported only once.
+  assert(!buffer_data_exported_);
+
+  auto it = buffer_data_map_.begin();
+  while (it != buffer_data_map_.end()) {
+    std::string buffer = it->second;
+    int64_t index = it->first;
     int64_t offset = result.size();
-    int64_t size = it.second.size();
+    int64_t size = buffer.size();
     uint64_t hash = tsl::Fingerprint64(buffer);
     if (hashcode_to_pos.find(hash) == hashcode_to_pos.end()) {
       hashcode_to_pos[hash] = std::make_pair(offset, size);
       buffer_idx_map_[index] = std::make_pair(offset, size);
-      result += std::string(it.second.begin(), it.second.end());
-      // Pad to be 16 bytes aligned
-      while (result.size() % 16 != 0) result += '\0';
+      result.Append(std::move(buffer));
+      // Pad to be 16 bytes aligned.
+      {
+        std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
+        result.Append(std::move(pad));
+      }
     } else {
       // only update offset/index.
       buffer_idx_map_[index] = hashcode_to_pos[hash];
     }
+    buffer_data_map_.erase(it);
+    it = buffer_data_map_.begin();
+    buffer_data_exported_ = true;
   }
   // pad 16 bytes for the last buffer for XNNPack
-  result += "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0";
+  result.Append("\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
   // pad to be 16 bytes aligned
-  while (result.size() % 16 != 0) result += '\0';
+  {
+    std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
+    result.Append(std::move(pad));
+  }
 
   for (auto& it : custom_op_data_map_) {
-    while (result.size() % 16 != 0) result += '\0';
+    {
+      std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
+      result.Append(std::move(pad));
+    }
     if (custom_option_alignment_.has_value()) {
-      while (result.size() % custom_option_alignment_.value() != 0)
-        result += '\0';
+      {
+        auto alignment = custom_option_alignment_.value();
+        std::string pad(alignment - result.size() % alignment, '\0');
+        result.Append(std::move(pad));
+      }
     }
     auto buffer = std::string(it.second.begin(), it.second.end());
     int64_t offset = result.size();
     int64_t size = it.second.size();
     custom_op_idx_map_[it.first] = std::make_pair(offset, size);
-    result += buffer;
+    result.Append(std::move(buffer));
   }
   // pad to be 16 bytes aligned
-  while (result.size() % 16 != 0) result += '\0';
+  {
+    std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
+    result.Append(std::move(pad));
+  }
 }
 
 bool Translator::UpdateBufferOffsets(tflite::Model* mutable_model) {
@@ -3568,8 +3730,8 @@ BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
   std::vector<BufferOffset<tflite::DimensionMetadata>> fb_dim_metadata(
       dim_size);
   for (int i = 0; i < dim_size; i++) {
-    const auto dim_metadata =
-        s_attr.getDimMetadata()[i].dyn_cast<mlir::TFL::DimensionMetadataAttr>();
+    const auto dim_metadata = mlir::dyn_cast<mlir::TFL::DimensionMetadataAttr>(
+        s_attr.getDimMetadata()[i]);
     if (dim_metadata.getFormat().getValue() ==
         mlir::TFL::DimensionType::DENSE) {
       fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index 0d477b51b6d467..bcc0244194dccd 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -77,6 +77,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/offset_buffer.h"
 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
+#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
@@ -96,7 +97,6 @@ limitations under the License.
#include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/graph_info.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/schema/mutable/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -770,6 +770,20 @@ StatusOr ConvertOp( mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs); mlir::BuiltinOptions2ToAttributes(op.builtin_options_2, builder, attrs); } + + if (builtin_code == tflite::BuiltinOperator_STABLEHLO_COMPOSITE) { + auto composite_options = op.builtin_options_2.AsStableHLOCompositeOptions(); + std::string decomposition = ""; + if (composite_options->decomposition_subgraph_index > -1) { + decomposition = + func_names.at(composite_options->decomposition_subgraph_index); + } + + attrs.emplace_back(builder.getNamedAttr( + "decomposition", + mlir::vhlo::StringV1Attr::get(builder.getContext(), decomposition))); + } + op_state.addAttributes(attrs); // Handle the conversion from subgraph index to functions for If and While. We diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index f72ef1f9641d48..3dfc21f5c2a07c 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/strings/str_cat.h" @@ -41,9 +42,11 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" @@ -52,7 +55,6 @@ limitations under the License. 
#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/lite/schema/mutable/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tsl/platform/status.h" @@ -176,7 +178,7 @@ ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray, std::vector intVec; intVec.reserve(attrArray.getValue().size()); for (auto attr : attrArray.getValue()) { - intVec.push_back(attr.cast().getInt()); + intVec.push_back(mlir::cast(attr).getInt()); } return builder->CreateVector(intVec); } @@ -188,7 +190,7 @@ ConvertF32ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray, floatVec.reserve(attrArray.getValue().size()); for (auto attr : attrArray.getValue()) { floatVec.push_back( - attr.cast().getValue().convertToFloat()); + mlir::cast(attr).getValue().convertToFloat()); } return builder->CreateVector(floatVec); } @@ -301,6 +303,21 @@ static mlir::Attribute BuildVhloArrayV1Attr(std::vector value, return mlir::vhlo::ArrayV1Attr::get(builder.getContext(), value); } +static mlir::Attribute BuildVhloDictionaryV1Attr( + std::vector> value, + mlir::Builder builder) { + return mlir::vhlo::DictionaryV1Attr::get(builder.getContext(), value); +} + +static mlir::Attribute BuildVhloFloatV1Attr(float value, + mlir::Builder builder) { + mlir::StablehloVhloTypeConverter type_converter; + auto vhlo_type = + type_converter.convertType(builder.getF32FloatAttr(value).getType()); + return mlir::vhlo::FloatV1Attr::get(builder.getContext(), vhlo_type, + ::llvm::APFloat(value)); +} + static mlir::Attribute BuildRankedTensorAttr(std::vector shape, std::vector value, mlir::Builder builder) { @@ -327,8 +344,8 @@ static mlir::Attribute BuildVhloTensorV1Attr(std::vector shape, std::vector value, mlir::Builder builder) { mlir::StablehloVhloTypeConverter type_converter; - auto builtin_attr = BuildRankedTensorAttr(shape, value, builder) - .dyn_cast(); + auto builtin_attr = mlir::dyn_cast( + BuildRankedTensorAttr(shape, value, builder)); auto vhlo_type = type_converter.convertType(builtin_attr.getType()); return mlir::vhlo::TensorV1Attr::get(builder.getContext(), vhlo_type, builtin_attr.getRawData()); @@ -338,8 +355,8 @@ static mlir::Attribute BuildVhloTensorV1Attr(std::vector shape, std::vector value, mlir::Builder builder) { mlir::StablehloVhloTypeConverter type_converter; - auto builtin_attr = BuildRankedTensorAttr(shape, value, builder) - .dyn_cast(); + auto builtin_attr = mlir::dyn_cast( + BuildRankedTensorAttr(shape, value, builder)); auto vhlo_type = type_converter.convertType(builtin_attr.getType()); return mlir::vhlo::TensorV1Attr::get(builder.getContext(), vhlo_type, builtin_attr.getRawData()); @@ -416,6 +433,33 @@ static mlir::Attribute BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value, return mlir::TFL::MirrorPaddingTypeAttr::get(builder.getContext(), padding); } +static std::vector BuildAttributeVectorFromFlatbuffer( + flexbuffers::Vector flatbuffer_vector, mlir::Builder builder) { + std::vector mlir_vector; + + for (int i = 0; i < flatbuffer_vector.size(); ++i) { + auto value = flatbuffer_vector[i]; + + if (value.IsBool()) { + mlir_vector.push_back(BuildVhloBooleanV1Attr(value.AsBool(), builder)); + } else if (value.IsString()) { + mlir_vector.push_back( + BuildVhloStringV1Attr(value.AsString().str(), builder)); + } else if (value.IsInt()) { + mlir_vector.push_back(BuildVhloIntV1Attr(value.AsInt64(), builder)); + } else if (value.IsFloat()) { + 
mlir_vector.push_back(BuildVhloFloatV1Attr(value.AsFloat(), builder)); + } else if (value.IsVector()) { + std::vector nested_mlir_vector = + BuildAttributeVectorFromFlatbuffer(value.AsVector(), builder); + mlir_vector.push_back( + BuildVhloArrayV1Attr(std::move(nested_mlir_vector), builder)); + } + } + + return mlir_vector; +} + static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value, mlir::Builder builder) { const char* option_name = tflite::EnumNamePadding(value); @@ -613,8 +657,6 @@ void BuiltinOptions2ToAttributesManual( bool has_side_effect_set = false; const flexbuffers::Map& computation_map = flexbuffers::GetRoot(op->custom_attributes).AsMap(); - std::vector symbol_vec; - symbol_vec.reserve(computation_map.size()); const auto& keys = computation_map.Keys(); for (size_t i = 0; i < keys.size(); ++i) { const auto key = keys[i].AsKey(); @@ -638,6 +680,61 @@ void BuiltinOptions2ToAttributesManual( "has_side_effect", BuildVhloBooleanV1Attr(false, builder))); return; } + if (const auto* op = op_union.AsStableHLOCompositeOptions()) { + attributes.emplace_back( + builder.getNamedAttr("name", BuildVhloStringV1Attr(op->name, builder))); + + attributes.emplace_back(builder.getNamedAttr( + "version", BuildVhloIntV1Attr(op->version, builder))); + + auto composite_attribute_pairs = + std::vector>(); + + auto composite_attributes = + flexbuffers::GetRoot(op->composite_attributes).AsMap(); + + const auto& keys = composite_attributes.Keys(); + for (size_t i = 0; i < keys.size(); ++i) { + const auto key = keys[i].AsKey(); + const auto& value = composite_attributes[key]; + + std::pair composite_attribute_pair; + composite_attribute_pair.first = BuildVhloStringV1Attr(key, builder); + + if (value.IsBool()) { + composite_attribute_pair.second = + BuildVhloBooleanV1Attr(value.AsBool(), builder); + } + if (value.IsString()) { + composite_attribute_pair.second = + BuildVhloStringV1Attr(value.AsString().str(), builder); + } + if (value.IsInt()) { + composite_attribute_pair.second = + BuildVhloIntV1Attr(value.AsInt64(), builder); + } + if (value.IsFloat()) { + composite_attribute_pair.second = + BuildVhloFloatV1Attr(value.AsFloat(), builder); + } + + if (value.IsVector()) { + std::vector mlir_vector = + BuildAttributeVectorFromFlatbuffer(value.AsVector(), builder); + + composite_attribute_pair.second = + BuildVhloArrayV1Attr(std::move(mlir_vector), builder); + } + + composite_attribute_pairs.emplace_back(composite_attribute_pair); + } + + attributes.emplace_back(builder.getNamedAttr( + "composite_attributes", + BuildVhloDictionaryV1Attr(std::move(composite_attribute_pairs), + builder))); + return; + } if (const auto* op = op_union.AsStablehloPadOptions()) { std::vector shape = { static_cast(op->edge_padding_low.size())}; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h index 381f2a4c024549..64865eb77b5c43 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h @@ -30,12 +30,13 @@ limitations under the License. 
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloTypes.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" -#include "tensorflow/lite/schema/mutable/schema_generated.h" namespace mlir { @@ -65,7 +66,7 @@ class StablehloVhloTypeConverter : public mlir::vhlo::VhloTypeConverter { return attr; if (auto stablehloAttr = - attr.dyn_cast_or_null()) { + mlir::dyn_cast_or_null(attr)) { return mlir::vhlo::TypeExtensionsV1Attr::get(stablehloAttr.getContext(), stablehloAttr.getBounds()); } @@ -88,7 +89,8 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { } Attribute convertEncoding(Attribute attr) const final { - if (auto vhloAttr = attr.dyn_cast_or_null()) { + if (auto vhloAttr = + mlir::dyn_cast_or_null(attr)) { return stablehlo::TypeExtensionsAttr::get(vhloAttr.getContext(), vhloAttr.getBounds()); } @@ -296,8 +298,8 @@ static inline std::vector GetVector( vhlo::TensorV1Attr elements, mlir::vhlo::VhloTypeConverter &vhlo_type_converter) { return GetOptionalVector(mlir::DenseIntElementsAttr::getFromRawBuffer( - vhlo_type_converter.convertType(elements.getType()) - .cast(), + mlir::cast( + vhlo_type_converter.convertType(elements.getType())), elements.getData())); } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc index b3e7e8e633e0da..df28f501ef7656 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc @@ -25,7 +25,7 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/minireflect.h" // from @flatbuffers -#include "tensorflow/lite/schema/reflection/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/reflection/schema_generated.h" #if FLATBUFFERS_LITTLEENDIAN == 0 #include "tensorflow/lite/core/model_builder.h" #endif diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index 3ae87672d0a99c..e2e10e2e712131 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -38,7 +38,6 @@ def TFL_Dialect : Dialect { let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; - let usePropertiesForAttributes = 0; let extraClassDeclaration = [{ ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 8ac81939d0d4de..1633820bb5bd5e 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -86,25 +86,13 @@ namespace { ParseResult parseOneResultSameOperandTypeOp(OpAsmParser& parser, OperationState& result) { SmallVector ops; - Type type; // If the operand list is in-between parentheses, then we have a generic form. // (see the fallback in `printOneResultOp`). 
- SMLoc loc = parser.getCurrentLocation(); if (!parser.parseOptionalLParen()) { - if (parser.parseOperandList(ops) || parser.parseRParen() || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(type)) - return failure(); - auto fnType = type.dyn_cast(); - if (!fnType) { - parser.emitError(loc, "expected function type"); - return failure(); - } - if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) - return failure(); - result.addTypes(fnType.getResults()); - return success(); + if (parser.parseOperandList(ops) || parser.parseRParen()) return failure(); + return parser.parseGenericOperationAfterOpName(result, ops); } + Type type; return failure(parser.parseOperandList(ops) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || @@ -1046,9 +1034,9 @@ mlir::LogicalResult CustomOp::verify() { LogicalResult CustomTfOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attr, OpaqueProperties, RegionRange ranges, + DictionaryAttr attr, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - CustomTfOpAdaptor op(operands, attr, {}, ranges); + CustomTfOpAdaptor op(operands, attr, properties, regions); if (op.getRegions().empty()) return success(); auto* real_op = &op.getBody().front().front(); @@ -1391,9 +1379,9 @@ static LogicalResult ComputeConvWindowedOutputSize( LogicalResult Conv2DOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attr, OpaqueProperties, RegionRange, + DictionaryAttr attr, OpaqueProperties properties, RegionRange, SmallVectorImpl& inferredReturnTypes) { - Conv2DOpAdaptor op(operands, attr); + Conv2DOpAdaptor op(operands, attr, properties); const Value input = op.getInput(); const Value filter = op.getFilter(); @@ -2072,9 +2060,9 @@ mlir::LogicalResult ReshapeOp::verify() { LogicalResult ReshapeOp::inferReturnTypes( MLIRContext* context, std::optional location, ValueRange operands, - DictionaryAttr attr, OpaqueProperties, RegionRange, + DictionaryAttr attr, OpaqueProperties properties, RegionRange, SmallVectorImpl& inferredReturnTypes) { - ReshapeOpAdaptor op(operands, attr); + ReshapeOpAdaptor op(operands, attr, properties); const Value input = op.getInput(); const Value shape = op.getShape(); @@ -2449,9 +2437,9 @@ void FakeQuantOp::getCanonicalizationPatterns(RewritePatternSet& results, LogicalResult UnpackOp::inferReturnTypes( MLIRContext* context, std::optional loc, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - UnpackOpAdaptor op(operands, attributes); + UnpackOpAdaptor op(operands, attributes, properties); // TODO(jpienaar): Refactor verify if (failed(op.verify(loc.has_value() ? 
*loc : UnknownLoc::get(context)))) return failure(); @@ -2810,7 +2798,7 @@ mlir::LogicalResult UnidirectionalSequenceLSTMOp::verify() { LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes( MLIRContext*, std::optional, ValueRange operands, - DictionaryAttr attr, OpaqueProperties, RegionRange, + DictionaryAttr attr, OpaqueProperties properties, RegionRange, SmallVectorImpl& inferredReturnTypes) { Value input = operands[0]; auto input_type = input.getType().dyn_cast_or_null(); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 5872fbc8c953c9..ed3a963bbb7523 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -34,10 +34,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_dialect.h.inc" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_enums.h.inc" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" -#include "tensorflow/lite/schema/schema_generated.h" #define GET_ATTRDEF_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_attrdefs.h.inc" diff --git a/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc b/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc index 4a4e7a65cd6cdc..80558576e08a38 100644 --- a/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc @@ -22,7 +22,6 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/idl.h" // from @flatbuffers #include "flatbuffers/util.h" // from @flatbuffers -#include "tensorflow/lite/schema/schema_generated.h" int main(int argc, char** argv) { // load FlatBuffer schema (.fbs) and JSON from disk diff --git a/tensorflow/compiler/mlir/lite/metrics/BUILD b/tensorflow/compiler/mlir/lite/metrics/BUILD index 6218a2fb30a829..464cd8f33822b7 100644 --- a/tensorflow/compiler/mlir/lite/metrics/BUILD +++ b/tensorflow/compiler/mlir/lite/metrics/BUILD @@ -72,5 +72,6 @@ cc_library( "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/lite/metrics/types_util.cc b/tensorflow/compiler/mlir/lite/metrics/types_util.cc index b47347ceb03827..7dd658e54dd12e 100644 --- a/tensorflow/compiler/mlir/lite/metrics/types_util.cc +++ b/tensorflow/compiler/mlir/lite/metrics/types_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" namespace mlir { @@ -67,8 +68,8 @@ class LocationExtractor : public Location { new_call->set_name(loc.getName().str()); // Add child as the source location. 
auto child_loc = loc.getChildLoc(); - if (child_loc.isa()) { - auto typed_child_loc = child_loc.dyn_cast(); + if (mlir::isa(child_loc)) { + auto typed_child_loc = mlir::dyn_cast(child_loc); ExtractFileLine(typed_child_loc, new_call->mutable_source()); } }) @@ -83,7 +84,7 @@ class LocationExtractor : public Location { // Skip the first location if it stores information for propagating // op_type metadata. if (num_locs > 0) { - if (auto name_loc = locations[0].dyn_cast()) { + if (auto name_loc = mlir::dyn_cast(locations[0])) { if (name_loc.getName().strref().ends_with(":")) { if (num_locs == 2) { return LocationExtractor(locations[1]).Extract(error_data); diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 203a06ff721a02..c7f50ba6edf81c 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -95,6 +95,7 @@ cc_library( "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:types_proto_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -143,10 +144,9 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/lite:flatbuffer_import", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", ], diff --git a/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc b/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc index 23b3714e73a758..6591251d9e915b 100644 --- a/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc +++ b/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc @@ -18,25 +18,19 @@ limitations under the License. #include #include -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/InitLLVM.h" +#include "absl/strings/string_view.h" #include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" -#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h index e5ce76612f3e01..11a8e28ebed0cc 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h @@ -18,6 +18,8 @@ limitations under the License. 
#include #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 16e12bbb6da04d..085478db128a71 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -21,14 +21,18 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "llvm/ADT/StringSet.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" @@ -42,9 +46,11 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { @@ -80,7 +86,7 @@ Status HandleInputOutputArraysWithModule( if (!input_attr) { return errors::InvalidArgument("no inputs attribute found"); } - auto input_names = input_attr.cast().getValue(); + auto input_names = mlir::cast(input_attr).getValue(); input_names.split(function_input_names, ",", /*MaxSplit=*/-1, /*KeepEmpty=*/false); const int function_input_names_size = function_input_names.size(); @@ -106,7 +112,7 @@ Status HandleInputOutputArraysWithModule( if (!output_attr) { return errors::InvalidArgument("no outputs attribute found"); } - auto output_names = output_attr.cast().getValue(); + auto output_names = mlir::cast(output_attr).getValue(); output_names.split(function_output_names, ",", /*MaxSplit=*/-1, /*KeepEmpty=*/false); const int function_output_names_size = function_output_names.size(); diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h index 87b5d80d025b47..a9170bdc86a4f5 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -19,6 +19,8 @@ limitations under the License. 
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 224db40942ef90..a4512d226939f7 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -43,10 +43,10 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/public/session.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index ef436759985dcb..a57b3585abadb1 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -24,10 +24,14 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc index 3de159a1414429..f6eac3e90ec8bd 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/Passes.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" @@ -80,7 +81,7 @@ LogicalResult QuantizedConstRewrite::matchAndRewrite( } // Is the constant value a type expressed in a way that we support? 
- if (!value.isa()) { + if (!mlir::isa(value)) { return failure(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc index e99addc5b5f8a5..a51956ce08a239 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc @@ -13,18 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/Passes.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" -using namespace mlir; -using namespace mlir::quantfork; - -namespace { +namespace mlir::quantfork { #define GEN_PASS_DEF_QUANTCONVERTSIMULATEDQUANT #include "tensorflow/compiler/mlir/lite/quantization/ir/Passes.h.inc" @@ -51,7 +59,6 @@ class FakeQuantRewrite : public OpRewritePattern { *hadFailure = true; return failure(); } - return success(); } @@ -66,7 +73,7 @@ class FakeQuantRewrite : public OpRewritePattern { quant::QuantizedType elementType = static_cast(this) - ->convertFakeQuantAttrsToType(op, converter.expressedType); + ->convertFakeQuantAttrsToType(op, converter.expressed_type); if (!elementType) { // Note that the fakeQuantAttrsToType will have emitted the error. @@ -81,7 +88,7 @@ class FakeQuantRewrite : public OpRewritePattern { // this is a forced/hard-coded constraint. 
    auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
                                                    op.getInputs());
-    rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
+    rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.input_type,
                                                   qbarrier.getResult());

     return false;
@@ -121,9 +128,9 @@ class ConstFakeQuantPerAxisRewrite
     min.reserve(fqOp.getMin().size());
     max.reserve(fqOp.getMax().size());
     for (auto m : fqOp.getMin())
-      min.push_back(m.cast<FloatAttr>().getValueAsDouble());
+      min.push_back(cast<FloatAttr>(m).getValueAsDouble());
     for (auto m : fqOp.getMax())
-      max.push_back(m.cast<FloatAttr>().getValueAsDouble());
+      max.push_back(cast<FloatAttr>(m).getValueAsDouble());

     return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.getNumBits(),
                                 fqOp.getAxis(), min, max, fqOp.getNarrowRange(),
@@ -131,8 +138,6 @@
   }
 };

-}  // namespace
-
 void ConvertSimulatedQuantPass::runOnOperation() {
   bool hadFailure = false;
   auto func = getOperation();
@@ -144,7 +149,8 @@ void ConvertSimulatedQuantPass::runOnOperation() {
   if (hadFailure) signalPassFailure();
 }

-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::quantfork::createConvertSimulatedQuantPass() {
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertSimulatedQuantPass() {
   return std::make_unique<ConvertSimulatedQuantPass>();
 }
+
+}  // namespace mlir::quantfork
diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc
index d111141958c403..8aa6475b888702 100644
--- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project

 using namespace mlir;
 using namespace mlir::quantfork;
@@ -51,20 +52,20 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor) {

 /// The quantization specification should match the expressed type.
 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
-  if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
+  if (auto typeAttr = mlir::dyn_cast<TypeAttr>(quantSpec)) {
     Type spec = typeAttr.getValue();
-    if (spec.isa<TensorType, VectorType>()) return false;
+    if (mlir::isa<TensorType, VectorType>(spec)) return false;

     // The spec should be either a quantized type which is compatible to the
     // expressed type, or a primitive type which is as same as the
     // (element type of) the expressed type.
-    if (auto quantizedType = spec.dyn_cast<quant::QuantizedType>())
+    if (auto quantizedType = mlir::dyn_cast<quant::QuantizedType>(spec))
       return quantizedType.isCompatibleExpressedType(expressed);

-    if (auto tensorType = expressed.dyn_cast<TensorType>())
+    if (auto tensorType = mlir::dyn_cast<TensorType>(expressed))
       return spec == tensorType.getElementType();

-    if (auto vectorType = expressed.dyn_cast<VectorType>())
+    if (auto vectorType = mlir::dyn_cast<VectorType>(expressed))
       return spec == vectorType.getElementType();
   }
   return false;
@@ -99,13 +100,13 @@ LogicalResult QuantizeRegionOp::verify() {
 }

 LogicalResult StatisticsOp::verify() {
-  auto tensorArg = getArg().getType().dyn_cast<TensorType>();
+  auto tensorArg = mlir::dyn_cast<TensorType>(getArg().getType());
   if (!tensorArg) return emitOpError("arg needs to be tensor type.");

   // Verify layerStats attribute.
{ auto layerStatsType = getLayerStats().getShapedType(); - if (!layerStatsType.getElementType().isa()) { + if (!mlir::isa(layerStatsType.getElementType())) { return emitOpError("layerStats must have a floating point element type"); } if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { @@ -122,7 +123,7 @@ LogicalResult StatisticsOp::verify() { std::multiplies()); auto axisStatsType = getAxisStats()->getShapedType(); - if (!axisStatsType.getElementType().isa()) { + if (!mlir::isa(axisStatsType.getElementType())) { return emitOpError("axisStats must have a floating point element type"); } if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 || diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsBase.td b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsBase.td index f9afdc41db1dac..ed7a16c74d0fb7 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsBase.td +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsBase.td @@ -27,7 +27,6 @@ include "mlir/IR/OpBase.td" def QuantizationFork_Dialect : Dialect { let name = "quantfork"; let cppNamespace = "::mlir::quantfork"; - let usePropertiesForAttributes = 0; } #endif // QUANT_FORK_BASE diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc index 919c711272b2c1..2ad06f77de8866 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc @@ -32,8 +32,8 @@ using namespace mlir::quantfork; static Attribute convertPrimitiveValueAttr( Attribute origRealValue, quant::QuantizedType quantizedElementType, const UniformQuantizedValueConverter &converter, Type &outConvertedType) { - if (origRealValue.isa()) { - FloatAttr floatAttr = origRealValue.cast(); + if (mlir::isa(origRealValue)) { + FloatAttr floatAttr = mlir::cast(origRealValue); outConvertedType = quantizedElementType.getStorageType(); return IntegerAttr::get(quantizedElementType.getStorageType(), converter.quantizeFloatToInt(floatAttr.getValue())); @@ -64,11 +64,11 @@ static SparseElementsAttr convertSparseElementsAttr( quant::QuantizedType quantizedElementType, const UniformQuantizedValueConverter &converter) { DenseElementsAttr realDenseAttr = realSparseAttr.getValues(); - if (!realDenseAttr.isa()) { + if (!mlir::isa(realDenseAttr)) { return nullptr; } DenseElementsAttr quantDenseAttr = - convertDenseFPElementsAttr(realDenseAttr.cast(), + convertDenseFPElementsAttr(mlir::cast(realDenseAttr), quantizedElementType, converter); if (!quantDenseAttr) { return nullptr; @@ -76,9 +76,9 @@ static SparseElementsAttr convertSparseElementsAttr( // Cast from an expressed-type-based type to storage-type-based type, // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). - ShapedType newSparseType = - quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) - .dyn_cast_or_null(); + ShapedType newSparseType = mlir::dyn_cast_or_null( + quantizedElementType.castExpressedToStorageType( + realSparseAttr.getType())); if (!newSparseType) { return nullptr; } @@ -93,17 +93,19 @@ Attribute mlir::quantfork::quantizeAttrUniform( Attribute realValue, quant::UniformQuantizedType quantizedElementType, const UniformQuantizedValueConverter &converter, Type &outConvertedType) { // Fork to handle different variants of constants supported. - if (realValue.isa()) { + if (mlir::isa(realValue)) { // Dense tensor or vector constant. 
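
quantizeAttrUniform above dispatches on the constant kind; underneath, each scalar is mapped by the inverse of the dequantize formula, clamped to the storage range. A self-contained sketch of that per-element transform (illustrative only, not the converter's exact rounding behavior):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Quantize one float: q = clamp(round(f / scale) + zero_point), the
    // inverse of f = scale * (q - zero_point).
    static int8_t QuantizeToInt8(float f, float scale, int32_t zero_point) {
      int32_t q = static_cast<int32_t>(std::lround(f / scale)) + zero_point;
      return static_cast<int8_t>(std::clamp<int32_t>(q, -128, 127));
    }
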
- auto converted = convertDenseFPElementsAttr( - realValue.cast(), quantizedElementType, converter); + auto converted = + convertDenseFPElementsAttr(mlir::cast(realValue), + quantizedElementType, converter); outConvertedType = converted.getType(); return converted; } - if (realValue.isa()) { + if (mlir::isa(realValue)) { // Sparse tensor or vector constant. - auto converted = convertSparseElementsAttr( - realValue.cast(), quantizedElementType, converter); + auto converted = + convertSparseElementsAttr(mlir::cast(realValue), + quantizedElementType, converter); outConvertedType = converted.getType(); return converted; } @@ -121,13 +123,14 @@ Attribute mlir::quantfork::quantizeAttr( Attribute realValue, quant::QuantizedType quantizedElementType, Type &outConvertedType) { if (auto uniformQuantized = - quantizedElementType.dyn_cast()) { + mlir::dyn_cast(quantizedElementType)) { UniformQuantizedValueConverter converter(uniformQuantized); return quantizeAttrUniform(realValue, uniformQuantized, converter, outConvertedType); } if (auto uniformQuantizedPerAxis = - quantizedElementType.dyn_cast()) { + mlir::dyn_cast( + quantizedElementType)) { UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis); auto converted = converter.convert(realValue); // TODO: why we need this outConvertedType? remove it? diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 66df4f528aa43d..48e8ebe35dde67 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -31,18 +31,19 @@ cc_library( "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", "//tensorflow/lite:framework", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/api", - "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) @@ -55,24 +56,24 @@ cc_library( "quantize_weights.h", ], deps = [ - ":quantize_model", "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/api", - "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@flatbuffers//:runtime_cc", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) @@ -112,11 +113,10 @@ tf_cc_binary( ], deps = [ ":quantize_model", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/lite:framework", - 
"//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/strings", + "//tensorflow/lite/c:c_api_types", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", ], ) @@ -164,12 +164,16 @@ tf_cc_test( ], deps = [ ":quantize_model", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite:string", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/core/api:error_reporter", "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/tools/optimize:test_util", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -198,15 +202,15 @@ tf_cc_test( ], deps = [ ":quantize_weights", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/tools/optimize:test_util", "@com_google_googletest//:gtest", "@flatbuffers", - "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:logging", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 19af4a756f9bea..12be81041d66de 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -20,16 +20,20 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" @@ -37,7 +41,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/core/api/error_reporter.h" namespace mlir { namespace lite { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index 50b397ba0206d2..665766d700512d 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -20,9 +20,9 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace lite { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index 02f07e98a1dbca..f1bf4363e797c3 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -27,14 +27,19 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "flatbuffers/flexbuffers.h" // from @flatbuffers -#include "tensorflow/core/lib/io/path.h" +#include "absl/container/flat_hash_set.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/lite/model.h" -#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/schema/schema_utils.h" +#include "tensorflow/lite/string_type.h" #include "tensorflow/lite/tools/optimize/test_util.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc index d55d41cc7a8e6e..e2581e7c53f7f4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc @@ -21,26 +21,30 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "llvm/ADT/SmallVector.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/stderr_reporter.h" namespace mlir { namespace lite { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h index 6c94e4c2d10c71..f92b58ffb3b01c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h @@ -22,9 +22,11 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/model.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace lite { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index 57d0fda20ba33a..7056b7a244fc59 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -21,12 +21,18 @@ limitations under the License. 
#include #include -#include "llvm/ADT/Twine.h" -#include "tensorflow/core/lib/io/path.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/schema/schema_utils.h" +#include "tensorflow/lite/stderr_reporter.h" #include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/platform/logging.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index 4bf154e892bcdb..73e6140d658c4c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -15,14 +15,16 @@ limitations under the License. #include +#include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h" -#include "tensorflow/lite/model.h" -#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/stderr_reporter.h" using llvm::cl::opt; diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc index df86c6fa6a5be9..339dfee21495ae 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc @@ -15,10 +15,13 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD index 43142a7a7c52dd..307741acda3439 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD @@ -45,7 +45,6 @@ tf_cc_test( "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", - "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", # buildcleaner: keep; prevents undefined reference "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc index 9a25d849ea7c8a..e6284d273e50d0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc @@ -268,7 +268,7 @@ Value SetNoFallbackAttr(PatternRewriter &rewriter, Value val) { // Returns true if the attr is a float attribute and be equal to value. static bool FloatValueEquals(const Attribute &attr, double value) { - auto fp_attr = attr.dyn_cast_or_null(); + auto fp_attr = mlir::dyn_cast_or_null(attr); if (fp_attr == nullptr) return false; if (fp_attr.isSplat()) { @@ -281,7 +281,7 @@ static bool FloatValueEquals(const Attribute &attr, double value) { // Returns true if the rank of the value equals to the given rank. 
bool RankEquals(Value value, int rank) { - auto rank_type = value.getType().template dyn_cast(); + auto rank_type = mlir::dyn_cast(value.getType()); return (rank_type && rank_type.getRank() == rank); } diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_default.mlir b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_default.mlir index 9c6d9b8aa8059b..b8a9f325b11077 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_default.mlir +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_default.mlir @@ -12,7 +12,7 @@ func.func @bias_add(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> ten func.func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> func.return %0: tensor<1xf32> -// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0, %arg1) {custom_code = "FlexAdd", custom_option = #tfl} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0, %arg1) <{custom_code = "FlexAdd", custom_option = #tfl}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: return %[[CUSTOM_0]] : tensor<1xf32> } @@ -20,7 +20,7 @@ func.func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { func.func @softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> func.return %0 : tensor<8x16xf32> -// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0) {custom_code = "FlexSoftmax", custom_option = #tfl} : (tensor<8x16xf32>) -> tensor<8x16xf32> +// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0) <{custom_code = "FlexSoftmax", custom_option = #tfl}> : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK: return %[[CUSTOM_0]] : tensor<8x16xf32> } @@ -52,7 +52,7 @@ func.func @conv2d_backprop_input_with_sub(%arg0: tensor<4xi32>, %arg1: tensor<3x func.func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> { %0 = "tf.DepthToSpace"(%arg0) {block_size = 2: i64, data_format = "NHWC"}: (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> func.return %0 : tensor<1x2x2x1xf32> -// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0) {custom_code = "FlexDepthToSpace", custom_option = #tfl} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> +// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0) <{custom_code = "FlexDepthToSpace", custom_option = #tfl}> : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> // CHECK: return %[[CUSTOM_0]] : tensor<1x2x2x1xf32> } @@ -60,7 +60,7 @@ func.func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> { func.func @floor_mod(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> func.return %0 : tensor<5xf32> -// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0, %arg1) {custom_code = "FlexFloorMod", custom_option = #tfl} : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> +// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0, %arg1) <{custom_code = "FlexFloorMod", custom_option = #tfl}> : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> // CHECK: return %[[CUSTOM_0]] : tensor<5xf32> } @@ -82,7 +82,7 @@ func.func @identity(%arg0: tensor<2xf32>) -> tensor<*xf32> { func.return %1 : tensor<*xf32> // CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1.000000e-03> : tensor}> {device = ""} : () -> tensor // CHECK: %[[IDENTITY_0:.*]] = "tf.Identity"(%arg0) {device = ""} : 
(tensor<2xf32>) -> tensor<*xf32> -// CHECK: %[[ADDV2_0:.*]] = "tfl.custom"(%0, %cst) {custom_code = "FlexAddV2", custom_option = #tfl} : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: %[[ADDV2_0:.*]] = "tfl.custom"(%0, %cst) <{custom_code = "FlexAddV2", custom_option = #tfl}> : (tensor<*xf32>, tensor) -> tensor<*xf32> // CHECK: return %[[ADDV2_0]] : tensor<*xf32> } @@ -148,7 +148,7 @@ func.func @conv_with_relu1_invalid_pattern(%arg0: tensor<1x3x4x3xf32>) -> (tenso // CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<[-1.000000e+00, -3.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32> // CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<[1.000000e+00, 3.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32> // CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x3x4x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x3x4x2xf32> -// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%[[CONV2D_0]], %[[CONST_2]]) {custom_code = "FlexMinimum", custom_option = #tfl} : (tensor<1x3x4x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> -// CHECK: %[[CUSTOM_1:.*]] = "tfl.custom"(%[[CUSTOM_0]], %[[CONST_1]]) {custom_code = "FlexMaximum", custom_option = #tfl} : (tensor<1x3x4x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%[[CONV2D_0]], %[[CONST_2]]) <{custom_code = "FlexMinimum", custom_option = #tfl}> : (tensor<1x3x4x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CUSTOM_1:.*]] = "tfl.custom"(%[[CUSTOM_0]], %[[CONST_1]]) <{custom_code = "FlexMaximum", custom_option = #tfl}> : (tensor<1x3x4x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> // CHECK: return %[[CUSTOM_1]] : tensor<1x3x4x2xf32> } diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_legacy.mlir b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_legacy.mlir index 5835d7d107cef5..258a006ee37fc6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_legacy.mlir +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_legacy.mlir @@ -52,7 +52,7 @@ func.func @conv2d_backprop_input_with_sub(%arg0: tensor<4xi32>, %arg1: tensor<3x func.func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> { %0 = "tf.DepthToSpace"(%arg0) {block_size = 2: i64, data_format = "NHWC"}: (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> func.return %0 : tensor<1x2x2x1xf32> -// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0) {custom_code = "FlexDepthToSpace", custom_option = #tfl} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> +// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0) <{custom_code = "FlexDepthToSpace", custom_option = #tfl}> : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> // CHECK: return %[[CUSTOM_0]] : tensor<1x2x2x1xf32> } @@ -60,7 +60,7 @@ func.func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> { func.func @floor_mod(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> func.return %0 : tensor<5xf32> -// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0, %arg1) {custom_code = "FlexFloorMod", custom_option = #tfl} : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> +// CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%arg0, %arg1) <{custom_code = "FlexFloorMod", custom_option = #tfl}> : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> // 
CHECK: return %[[CUSTOM_0]] : tensor<5xf32> } diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 55790c40509946..b4015181886788 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -133,7 +133,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp if (PerAxis) { // This is a special case that the quant_dim is the last dimensions // according to the tf.FakeQuantWithMinMaxPerChannel. - quant_dim = res.getType().template cast().getRank() - 1; + quant_dim = mlir::cast(res.getType()).getRank() - 1; } // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. diff --git a/tensorflow/compiler/mlir/lite/quantization/tests/import_quant_stats.mlir b/tensorflow/compiler/mlir/lite/quantization/tests/import_quant_stats.mlir index 1260089c0f264a..5bcb6837f14d81 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tests/import_quant_stats.mlir +++ b/tensorflow/compiler/mlir/lite/quantization/tests/import_quant_stats.mlir @@ -18,8 +18,8 @@ func.func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor) -> (tensor func.return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> // CHECK-NEXT: %[[split:.*]]:2 = "tfl.split" -// CHECK-NEXT: %[[stats1:.*]] = "quantfork.stats"(%[[split]]#0) {layerStats = dense<[-1.000000e+00, 1.000000e+00]> -// CHECK-NEXT: %[[stats2:.*]] = "quantfork.stats"(%[[split]]#1) {layerStats = dense<[-1.000000e+00, 1.000000e+00]> +// CHECK-NEXT: %[[stats1:.*]] = "quantfork.stats"(%[[split]]#0) <{layerStats = dense<[-1.000000e+00, 1.000000e+00]> +// CHECK-NEXT: %[[stats2:.*]] = "quantfork.stats"(%[[split]]#1) <{layerStats = dense<[-1.000000e+00, 1.000000e+00]> // CHECK-NEXT: return %[[stats1]], %[[stats2]] : tensor<2xf32>, tensor<2xf32> } @@ -30,7 +30,7 @@ func.func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor) -> (t func.return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> // CHECK-NEXT: %[[split:.*]]:2 = "tfl.split" -// CHECK-NEXT: %[[stats1:.*]] = "quantfork.stats"(%[[split]]#0) {layerStats = dense<[-2.000000e+00, 2.000000e+00]> +// CHECK-NEXT: %[[stats1:.*]] = "quantfork.stats"(%[[split]]#0) <{layerStats = dense<[-2.000000e+00, 2.000000e+00]> // CHECK-NEXT: return %[[stats1]], %[[split]]#1 : tensor<2xf32>, tensor<2xf32> } @@ -41,7 +41,7 @@ func.func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor) -> ( func.return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> // CHECK-NEXT: %[[split:.*]]:2 = "tfl.split" -// CHECK-NEXT: %[[stats1:.*]] = "quantfork.stats"(%[[split]]#0) {layerStats = dense<[-3.000000e+00, 3.000000e+00]> -// CHECK-NEXT: %[[stats2:.*]] = "quantfork.stats"(%[[split]]#1) {layerStats = dense<[-3.000000e+00, 3.000000e+00]> +// CHECK-NEXT: %[[stats1:.*]] = "quantfork.stats"(%[[split]]#0) <{layerStats = dense<[-3.000000e+00, 3.000000e+00]> +// CHECK-NEXT: %[[stats2:.*]] = "quantfork.stats"(%[[split]]#1) <{layerStats = dense<[-3.000000e+00, 3.000000e+00]> // CHECK-NEXT: return %[[stats1]], %[[stats2]] : tensor<2xf32>, tensor<2xf32> } diff --git a/tensorflow/compiler/mlir/lite/schema/BUILD b/tensorflow/compiler/mlir/lite/schema/BUILD new file mode 100644 index 00000000000000..34b799a9738741 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/schema/BUILD @@ -0,0 +1,46 @@ +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") +load("//tensorflow:tensorflow.default.bzl", 
"get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +exports_files( + srcs = ["schema.fbs"], +) + +flatbuffer_cc_library( + name = "schema_fbs", + srcs = ["schema.fbs"], + compatible_with = get_compatible_with_portable(), +) + +# Generic schema for flatbuffer converter (but with mutable makes bigger). +flatbuffer_cc_library( + name = "schema_fbs_with_mutable", + srcs = ["schema.fbs"], + compatible_with = get_compatible_with_portable(), + flatc_args = [ + "--gen-mutable", + "--gen-object-api", + ], + out_prefix = "mutable/", +) + +# Generic schema for inference on device (but with reflections makes bigger). +flatbuffer_cc_library( + name = "schema_fbs_with_reflection", + srcs = ["schema.fbs"], + compatible_with = get_compatible_with_portable(), + flatc_args = [ + "--reflect-types", + "--reflect-names", + "--no-union-value-namespacing", + "--gen-object-api", + ], + out_prefix = "reflection/", +) diff --git a/tensorflow/compiler/mlir/lite/schema/README.md b/tensorflow/compiler/mlir/lite/schema/README.md new file mode 100644 index 00000000000000..369027689e09e2 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/schema/README.md @@ -0,0 +1,2 @@ +This directory contains schema related files and targets that are used by both +the TFL converter (tf/compiler/mlir/lite/) and the runtime (tf/lite/). \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/schema/schema.fbs b/tensorflow/compiler/mlir/lite/schema/schema.fbs new file mode 100644 index 00000000000000..7ab78be26737ee --- /dev/null +++ b/tensorflow/compiler/mlir/lite/schema/schema.fbs @@ -0,0 +1,1653 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. +// Version 3a: Add new builtin op code field. Has backward compatibility with +// version 3. +// Version 3b: Rename fields in SignatureDef. Has backward compatibility with +// version 3 and 3a. +// Version 3c: Move constant tensor buffers & custom op buffers outside from +// Flatbuffers. Has backward compatibility with version 3, 3a and +// 3b. + +namespace tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// IMPORTANT: All new members of tables, enums and unions must be added at the +// end to ensure backwards compatibility. + +// The type of data stored in a tensor. 
+enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, + BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, + INT8 = 9, + FLOAT64 = 10, + COMPLEX128 = 11, + UINT64 = 12, + // Experimental: Resource and variant types are experimental, that are subject + // to change. Do not implement custom kernels using resource & variant types + // now. + RESOURCE = 13, + VARIANT = 14, + UINT32 = 15, + UINT16 = 16, + INT4 = 17, + BFLOAT16 = 18, +} + +// Custom quantization parameters for experimenting with new quantization +// techniques. +table CustomQuantization { + custom:[ubyte] (force_align: 16); +} + +// Represents a specific quantization technique's parameters. +union QuantizationDetails { + CustomQuantization, +} + +// Parameters for converting a quantized tensor back to float. +table QuantizationParameters { + // These four parameters are the asymmetric linear quantization parameters. + // Given a quantized value q, the corresponding float value f should be: + // f = scale * (q - zero_point) + // For other quantization types, the QuantizationDetails below is used. + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; // For dequantizing the tensor's values. + zero_point:[long]; + + // If this is not none, the other quantization parameters (i.e. min, max, + // scale, zero_point fields above) are ignored and the value of the + // QuantizationDetails union should be used. + details:QuantizationDetails; + + // Specifies the dimension of the Tensor's shape that the scales and + // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + // with quantization params: + // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 + // will be quantized across the second dimension of t. + // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 + // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 + quantized_dimension:int; +} + +// Sparse tensors. +// We use a modification of the TACO format. +// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf +// +// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), +// potentially with a k-dimensional block (0 <= k <= n) with dims +// (dn, ..., dn+k-1), the format needs to specify: +// 1. In what order to traverse these dimensions. For example, to store a 2-D +// matrix in row major order, the traversal order would be (d0, d1), +// whereas to store it in column major order, the traversal order would be +// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order +// could be (d0, d1, d2, d3). +// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original +// tensor dimension in (d0, ..., dn-1). +// 3. In the traversal order defined above, the format (dense vs. sparse) and +// index metadata for each dimension. For a dense dimension, this is just +// the size of that dimension. For a sparse dimension, it's the same as +// the compressed index defined in the Compressed Sparse Row (CSR) format. +// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) + +// The storage type for a dimension. Currently we support: +// 1. DENSE: each coordinate in this dimension is stored implicitly. +// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The +// compression technique is the same what CSR uses. 
+// More types like a sparse dimension with a different compression technique +// could be added to the list in the future. +enum DimensionType : byte { + DENSE = 0, + SPARSE_CSR = 1, +} + +table Int32Vector { + values:[int]; +} + +table Uint16Vector { + values:[ushort] (force_align: 4); +} + +table Uint8Vector { + values:[ubyte] (force_align: 4); +} + +// Variable-typed buffer to store the index metadata for a sparse dimension. +// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 +// vector. We don't want the per-dimensional index to overflow that range. +union SparseIndexVector { + Int32Vector, + Uint16Vector, + Uint8Vector +} + +table DimensionMetadata { + // Whether a dimension is dense or sparse. + format:DimensionType; + // Index metadata used for a dimension. + // - If format is DimensionType.DENSE then we use the dense_size field to + // store the size of that dimension. Each index in that dimension is + // stored implicitly. + // - If format is DimensionType.SPARSE_CSR then we use array_segments and + // array_indices to encode that dimension. array_segments represents how + // to segment the indices array, each segment corresponds to one element + // in the previous dimension. array_indices represents the index of the + // non-zero elements within this dimension (as those in the CSR matrix + // format, where the first array is row pointers and the second array is + // column indices). + dense_size:int; + array_segments:SparseIndexVector; + array_indices:SparseIndexVector; +} + +// Parameters to encode a sparse TfLite tensor. +table SparsityParameters { + // The traversal order of the dimensions defined in the `shape` field of the + // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1, + // ..., dn-1), + // - if not block sparse, the traversal_order is just a permutation of (d0, + // ..., dn-1). For example, a 2-D matrix stored in row-major order would + // have traversal_order = (d0, d1). + // - if block sparse with a k-dimensional block (0 <= k <= n), the + // traversal_order has n + k elements. The first n elements are still a + // permutation of (d0, ..., dn-1). The lask k elements are a permutation + // of (dn, ..., dn+k-1), defining how to traverse a block internally. For + // example, a 2-D matrix with 2-D blocks, both stored in row-major order + // would have traversal_order = (d0, d1, d2, d3). + traversal_order:[int]; + // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + // stores how a block dimension in (dn, ..., dn+k-1) maps to the original + // tensor dimension in (d0, ..., dn). + // It's stored in the order of (dn, ..., dn+k-1). + // If not block-sparse, this field is NULL. + block_map:[int]; + // In the traversal order defined above, the metadata needed for + // each dimension to locate the non-zero values in the original dense tensor. + // The size of the dim_metadata array = the size of the traversal_order array + // = n + k. + dim_metadata:[DimensionMetadata]; +} + +// The nested tensor type for VARIANT type. +table VariantSubType { + // The tensor shape. + shape:[int]; + type:TensorType; + // If false, the rank or the number of tensor dimensions is unknown. + // If false, "shape" must be []. + has_rank: bool = false; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, height, width, number of channels] (That's + // Tensorflow's NHWC). 
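
As a worked example of the affine mapping documented in QuantizationParameters above, f = scale * (q - zero_point), with per-axis parameters selected by the coordinate along quantized_dimension; the helper below is illustrative:

    #include <cstdint>
    #include <vector>

    struct PerAxisQuant {
      std::vector<float> scale;         // one entry per slice of the quantized axis
      std::vector<int64_t> zero_point;  // same length as scale
    };

    // Dequantizes q taken from slice `axis_index` along quantized_dimension.
    static float Dequantize(int8_t q, const PerAxisQuant& p, size_t axis_index) {
      return p.scale[axis_index] *
             static_cast<float>(static_cast<int64_t>(q) - p.zero_point[axis_index]);
    }

    // With scale = {1.0, 2.0, 3.0} and zero_point = {1, 2, 3} as in the comment
    // above, q = 5 in slice 1 dequantizes to 2.0 * (5 - 2) = 6.0f.
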
+ shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existent empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; + + // Parameters to encode a sparse tensor. See the example in + // tensorflow/lite/testdata/sparse_tensor.json. + sparsity:SparsityParameters; // Optional. + + // Encodes `shape` with unknown dimensions. Unknown dimensions are + // represented with -1. + shape_signature:[int]; // Optional. + + // This field is added to distinguish between scalars and tensors of unknown + // ranks (both of which shape is []). + // For scalars (rank = 0), shape = [] and has_rank = true. + // For tensors with known rank (rank > 0) and shape, shape = [...] and + // has_rank = true. + // For tensors with unknown rank and shape, shape = [] and has_rank = false. + has_rank: bool = false; + + // The nested Tensor types for VARIANT type. This is always empty for + // non-VARIANT types. This is optional because the nested type can be omitted. + // Currently only 1 subtype is supported. The field is defined as an array for + // flexibility of supporting multiple subtypes in the future. + variant_tensors:[VariantSubType]; +} + +// A list of builtin operators. Builtin operators are slightly faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +// LINT.IfChange +enum BuiltinOperator : int32 { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + DEPTH_TO_SPACE = 5, + DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed + // since different model developers use RELU1 in different ways. Never + // create another op called RELU1. + RELU_N1_TO_1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, + PAD = 34, + UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, + BATCH_TO_SPACE_ND = 37, + SPACE_TO_BATCH_ND = 38, + TRANSPOSE = 39, + MEAN = 40, + SUB = 41, + DIV = 42, + SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. 
+ // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, + ARG_MAX = 56, + MINIMUM = 57, + LESS = 58, + NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, + UNPACK = 88, + REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, + SQUARE = 92, + ZEROS_LIKE = 93, + FILL = 94, + FLOOR_MOD = 95, + RANGE = 96, + RESIZE_NEAREST_NEIGHBOR = 97, + LEAKY_RELU = 98, + SQUARED_DIFFERENCE = 99, + MIRROR_PAD = 100, + ABS = 101, + SPLIT_V = 102, + UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, + QUANTIZE = 114, + MATRIX_SET_DIAG = 115, + ROUND = 116, + HARD_SWISH = 117, + IF = 118, + WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, + SCATTER_ND = 122, + SELECT_V2 = 123, + DENSIFY = 124, + SEGMENT_SUM = 125, + BATCH_MATMUL = 126, + PLACEHOLDER_FOR_GREATER_OP_CODES = 127, + CUMSUM = 128, + CALL_ONCE = 129, + BROADCAST_TO = 130, + RFFT2D = 131, + CONV_3D = 132, + IMAG=133, + REAL=134, + COMPLEX_ABS=135, + HASHTABLE = 136, + HASHTABLE_FIND = 137, + HASHTABLE_IMPORT = 138, + HASHTABLE_SIZE = 139, + REDUCE_ALL = 140, + CONV_3D_TRANSPOSE = 141, + VAR_HANDLE = 142, + READ_VARIABLE = 143, + ASSIGN_VARIABLE = 144, + BROADCAST_ARGS = 145, + RANDOM_STANDARD_NORMAL = 146, + BUCKETIZE = 147, + RANDOM_UNIFORM = 148, + MULTINOMIAL = 149, + GELU = 150, + DYNAMIC_UPDATE_SLICE = 151, + RELU_0_TO_1 = 152, + UNSORTED_SEGMENT_PROD = 153, + UNSORTED_SEGMENT_MAX = 154, + UNSORTED_SEGMENT_SUM = 155, + ATAN2 = 156, + UNSORTED_SEGMENT_MIN = 157, + SIGN = 158, + BITCAST = 159, + BITWISE_XOR = 160, + RIGHT_SHIFT = 161, + // All Operators start with STABLEHLO_ prefixes are subject to change + // Many of the ops below can not be executed by TFlite runtime + STABLEHLO_LOGISTIC = 162, // WARNING: Do not have runtime support + STABLEHLO_ADD = 163, + STABLEHLO_DIVIDE = 164, // WARNING: No runtime support yet + STABLEHLO_MULTIPLY = 165, + STABLEHLO_MAXIMUM = 166, + STABLEHLO_RESHAPE = 167, // WARNING: No runtime support yet + STABLEHLO_CLAMP = 168, // WARNING: No runtime support + STABLEHLO_CONCATENATE = 169, // WARNING: No runtime support + STABLEHLO_BROADCAST_IN_DIM = 170, // WARNING: No runtime support + STABLEHLO_CONVOLUTION = 171, // WARNING: No runtime support + STABLEHLO_SLICE = 172, // WARNING: No runtime support + STABLEHLO_CUSTOM_CALL = 173, // WARNING: No runtime support + STABLEHLO_REDUCE = 174, // WARNING: No runtime support + STABLEHLO_ABS = 175, // WARNING: No runtime support + STABLEHLO_AND = 176, // WARNING: No runtime support + STABLEHLO_COSINE = 177, // WARNING: No runtime support + STABLEHLO_EXPONENTIAL = 178, // WARNING: No runtime support + STABLEHLO_FLOOR = 179, // WARNING: No runtime support + STABLEHLO_LOG = 180, // WARNING: No runtime support + STABLEHLO_MINIMUM = 181, + STABLEHLO_NEGATE = 182, // WARNING: No runtime support + STABLEHLO_OR = 183, // WARNING: No runtime support + STABLEHLO_POWER = 184, // WARNING: No runtime support + 
STABLEHLO_REMAINDER = 185, // WARNING: No runtime support + STABLEHLO_RSQRT = 186, // WARNING: No runtime support + STABLEHLO_SELECT = 187, // WARNING: No runtime support + STABLEHLO_SUBTRACT = 188, // WARNING: No runtime support + STABLEHLO_TANH = 189, // WARNING: No runtime support + STABLEHLO_SCATTER = 190, + STABLEHLO_COMPARE = 191, // WARNING: No runtime support + STABLEHLO_CONVERT = 192, // WARNING: No runtime support + STABLEHLO_DYNAMIC_SLICE = 193, // WARNING: No runtime support + STABLEHLO_DYNAMIC_UPDATE_SLICE = 194, // WARNING: No runtime support + STABLEHLO_PAD = 195, + STABLEHLO_IOTA = 196, // WARNING: No runtime support + STABLEHLO_DOT_GENERAL = 197, // WARNING: No runtime support + STABLEHLO_REDUCE_WINDOW = 198, + STABLEHLO_SORT = 199, // WARNING: No runtime support + STABLEHLO_WHILE = 200, // WARNING: No runtime support + STABLEHLO_GATHER = 201, + STABLEHLO_TRANSPOSE = 202, // WARNING: No runtime support + DILATE = 203, + STABLEHLO_RNG_BIT_GENERATOR = 204, + REDUCE_WINDOW = 205 (deprecated), + STABLEHLO_COMPOSITE = 206, // WARNING: No runtime support +} +// LINT.ThenChange(nnapi_linter/linter.proto) + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, + PadOptions, + GatherOptions, + BatchToSpaceNDOptions, + SpaceToBatchNDOptions, + TransposeOptions, + ReducerOptions, + SubOptions, + DivOptions, + SqueezeOptions, + SequenceRNNOptions, + StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumMinimumOptions, + ArgMaxOptions, + LessOptions, + NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, + UnpackOptions, + FloorDivOptions, + SquareOptions, + ZerosLikeOptions, + FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, + FloorModOptions, + RangeOptions, + ResizeNearestNeighborOptions, + LeakyReluOptions, + SquaredDifferenceOptions, + MirrorPadOptions, + AbsOptions, + SplitVOptions, + UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, + MatrixDiagOptions, + QuantizeOptions, + MatrixSetDiagOptions, + HardSwishOptions, + IfOptions, + WhileOptions, + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options, + ScatterNdOptions, + SelectV2Options, + DensifyOptions, + SegmentSumOptions, + BatchMatMulOptions, + CumsumOptions, + CallOnceOptions, + BroadcastToOptions, + Rfft2dOptions, + Conv3DOptions, + HashtableOptions, + HashtableFindOptions, + HashtableImportOptions, + HashtableSizeOptions, + VarHandleOptions, + ReadVariableOptions, + AssignVariableOptions, + RandomOptions, + BucketizeOptions, + GeluOptions, + DynamicUpdateSliceOptions, + 
UnsortedSegmentProdOptions,
+  UnsortedSegmentMaxOptions,
+  UnsortedSegmentMinOptions,
+  UnsortedSegmentSumOptions,
+  ATan2Options,
+  SignOptions,
+  BitcastOptions,
+  BitwiseXorOptions,
+  RightShiftOptions,
+  // DO NOT add new options to this union; doing so will cause a failure in
+  // Java API generation.
+  // Add new builtin options to BuiltinOptions2 instead.
+}
+
+union BuiltinOptions2 {
+  StablehloConcatenateOptions,
+  StablehloBroadcastInDimOptions,
+  StablehloSliceOptions,
+  StablehloConvolutionOptions,
+  StablehloCustomCallOptions,
+  StablehloReduceOptions,
+  StablehloScatterOptions,
+  StablehloCompareOptions,
+  StablehloDynamicSliceOptions,
+  StablehloPadOptions,
+  StablehloIotaOptions,
+  StablehloDotGeneralOptions,
+  StablehloReduceWindowOptions,
+  StablehloSortOptions,
+  StablehloWhileOptions,
+  StablehloGatherOptions,
+  StablehloTransposeOptions,
+  DilateOptions,
+  StablehloRngBitGeneratorOptions,
+  ReduceWindowOptions (deprecated),
+  StableHLOCompositeOptions,
+}
+
+table StablehloGatherOptions {
+  offset_dims : [long];
+  collapsed_slice_dims : [long];
+  start_index_map : [long];
+  index_vector_dim : long;
+  slice_sizes : [long];
+  indices_are_sorted : bool;
+}
+
+table StablehloTransposeOptions {
+  permutation : [long];
+}
+
+enum StablehloPrecisionConfig : uint {
+  DEFAULT,
+  HIGH,
+  HIGHEST,
+}
+
+table StablehloDotGeneralOptions {
+  lhs_batching_dimensions : [long];
+  rhs_batching_dimensions : [long];
+  lhs_contracting_dimensions : [long];
+  rhs_contracting_dimensions : [long];
+  precision_config : [StablehloPrecisionConfig];
+}
+
+table StablehloReduceWindowOptions {
+  window_dimensions : [long];
+  window_strides : [long];
+  base_dilations : [long];
+  window_dilations : [long];
+  padding : [long];
+  body_subgraph_index : int;
+}
+
+table StablehloWhileOptions {
+  cond_subgraph_index : int;
+  body_subgraph_index : int;
+}
+
+table StablehloSortOptions {
+  dimension : long;
+  is_stable : bool;
+  comparator_subgraph_index : int;
+}
+
+table StablehloConcatenateOptions {
+  dimension : long;
+}
+
+table StablehloBroadcastInDimOptions {
+  broadcast_dimensions : [long];
+}
+
+enum StablehloComparisonDirection : uint {
+  STABLEHLO_COMPARISON_DIRECTION_EQ,
+  STABLEHLO_COMPARISON_DIRECTION_NE,
+  STABLEHLO_COMPARISON_DIRECTION_GE,
+  STABLEHLO_COMPARISON_DIRECTION_GT,
+  STABLEHLO_COMPARISON_DIRECTION_LE,
+  STABLEHLO_COMPARISON_DIRECTION_LT,
+}
+
+enum StablehloComparisonType : uint {
+  STABLEHLO_COMPARISON_TYPE_NOTYPE,
+  STABLEHLO_COMPARISON_TYPE_FLOAT,
+  STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
+  STABLEHLO_COMPARISON_TYPE_SIGNED,
+  STABLEHLO_COMPARISON_TYPE_UNSIGNED,
+}
+
+table StablehloCompareOptions {
+  comparison_direction : StablehloComparisonDirection;
+  compare_type : StablehloComparisonType;
+}
+
+table StablehloDynamicSliceOptions {
+  slice_sizes : [long];
+}
+
+table StablehloPadOptions {
+  edge_padding_low : [long];
+  edge_padding_high : [long];
+  interior_padding : [long];
+}
+
+table StablehloIotaOptions {
+  iota_dimension : long;
+}
+
+table StablehloCustomCallOptions {
+  call_target_name : string;
+  has_side_effect : bool;
+  backend_config: string;
+  api_version : int;  // will be deprecated
+  called_computations: [int];  // should point to subgraphs of the computations
+  custom_attributes : [ubyte];
+}
+
+table StablehloReduceOptions {
+  dimensions : [long];
+  body_subgraph_index : int;
+}
+
+table StablehloSliceOptions {
+  start_indices : [long];
+  limit_indices : [long];
+  strides : [long];
+}
+
+table StablehloConvolutionOptions {
+  window_strides : [long];
+  padding :
[long]; + lhs_dilation : [long]; + rhs_dilation : [long]; + window_reversal : [bool]; + input_batch_dimension : long; + input_feature_dimension : long; + input_spatial_dimensions : [long]; + kernel_input_feature_dimension : long; + kernel_output_feature_dimension : long; + kernel_spatial_dimensions : [long]; + output_batch_dimension : long; + output_feature_dimension : long; + output_spatial_dimensions : [long]; + feature_group_count : long; + batch_group_count : long; + precision_config : [StablehloPrecisionConfig]; +} + +table StablehloScatterOptions { + indices_are_sorted: bool; + update_window_dims: [long]; + inserted_window_dims: [long]; + scatter_dims_to_operand_dims: [long]; + index_vector_dim: long; + unique_indices: bool; + update_computation_subgraph_index: int; +} + +enum RngAlgorithm : byte { + // An algorithm auto-selected by the system according to device type. + DEFAULT = 0, + // The Philox algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + PHILOX = 1, + // The ThreeFry algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + THREEFRY = 2, +} + +table StablehloRngBitGeneratorOptions { + algorithm:RngAlgorithm; +} + +// LINT.IfChange +enum Padding : byte { SAME, VALID } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +// LINT.IfChange +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU_N1_TO_1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; + // Parameters for Conv2D version 8 or above. + // When set, quantized_bias_type defines the dtype for both bias and accumulator. + quantized_bias_type: TensorType; +} + +// Options for both Conv3D and Conv3DTranspose. +table Conv3DOptions { + padding:Padding; + stride_d:int; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_d_factor:int = 1; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. + padding:Padding; + stride_w:int; + stride_h:int; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // See comments in lite/c/builtin_op_data.h for more details. + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. 
+  asymmetric_quantize_inputs:bool;
+}
+
+// An implementation of TensorFlow RNNCell.
+table RNNOptions {
+  fused_activation_function:ActivationFunctionType;
+  asymmetric_quantize_inputs:bool;
+}
+
+// An implementation of TensorFlow dynamic_rnn with RNNCell.
+table SequenceRNNOptions {
+  time_major:bool;
+  fused_activation_function:ActivationFunctionType;
+  asymmetric_quantize_inputs:bool;
+}
+
+// An implementation of TensorFlow bidirectional_dynamic_rnn with RNNCell.
+table BidirectionalSequenceRNNOptions {
+  time_major:bool;
+  fused_activation_function:ActivationFunctionType;
+  merge_outputs: bool;
+  asymmetric_quantize_inputs:bool;
+}
+
+// LINT.IfChange
+enum FullyConnectedOptionsWeightsFormat: byte {
+  DEFAULT = 0,
+  SHUFFLED4x16INT8 = 1,
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td)
+
+// An implementation of TensorFlow fully_connected (a.k.a. Dense) layer.
+table FullyConnectedOptions {
+  // Parameters for FullyConnected version 1 or above.
+  fused_activation_function:ActivationFunctionType;
+
+  // Parameters for FullyConnected version 2 or above.
+  weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;
+
+  // Parameters for FullyConnected version 5 or above.
+  // If set to true, then the number of dimensions is preserved. Furthermore,
+  // all but the last dimension of the input and output shapes will be equal.
+  keep_num_dims: bool;
+
+  // Parameters for FullyConnected version 7 or above.
+  // If set to true, then a weights-only op will use asymmetric quantization
+  // for inputs.
+  asymmetric_quantize_inputs: bool;
+
+  // Parameters for FullyConnected version 11 or above.
+  // When set, quantized_bias_type defines the dtype for both bias and
+  // accumulator.
+  quantized_bias_type: TensorType;
+}
+
+table SoftmaxOptions {
+  beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+  axis:int;
+  fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+  fused_activation_function:ActivationFunctionType;
+  // Parameters supported by version 3.
+  pot_scale_int16:bool = true;
+}
+
+table MulOptions {
+  fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+  // This field is currently ignored in the L2 Norm Op.
+  fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormalizationOptions {
+  radius:int;
+  bias:float;
+  alpha:float;
+  beta:float;
+}
+
+// LINT.IfChange
+enum LSTMKernelType : byte {
+  // Full LSTM kernel which supports peephole and projection.
+  FULL = 0,
+  // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
+  BASIC = 1,
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td)
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell.
+table LSTMOptions {
+  // Parameters for LSTM version 1 or above.
+  fused_activation_function:ActivationFunctionType;
+  cell_clip: float; // Optional, 0.0 means no clipping
+  proj_clip: float; // Optional, 0.0 means no clipping
+
+  // Parameters for LSTM version 2 or above.
+  // Basic kernel is only supported in version 2 or above.
+  kernel_type: LSTMKernelType = FULL;
+
+  // Parameters for LSTM version 4 or above.
+  asymmetric_quantize_inputs: bool;
+}
+
+// An implementation of TensorFlow dynamic_rnn with LSTMCell.
+table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true then first dimension is sequence, otherwise batch. + time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 3. + asymmetric_quantize_inputs:bool; + + // Parameter for unidirectional sequence RNN version 4. + diagonal_recurrent_tensors:bool; +} + +table BidirectionalSequenceLSTMOptions { + // Parameters supported by version 1: + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; + + // Parameters supported by version 2: + // If true then first dimension is sequence, otherwise batch. + // Version 1 implementations assumed time_major to be true, so this default + // value should never change. + time_major: bool = true; + + // Parameters for version 3 or above. + asymmetric_quantize_inputs:bool; +} + +table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; + half_pixel_centers: bool; +} + +table ResizeNearestNeighborOptions { + align_corners: bool; + half_pixel_centers: bool; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table PadOptions { +} + +table PadV2Options { +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SpaceToBatchNDOptions { +} + +table BatchToSpaceNDOptions { +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +table DepthToSpaceOptions { + block_size: int; +} + +table SubOptions { + fused_activation_function:ActivationFunctionType; + // Parameters supported by version 5 + pot_scale_int16:bool = true; +} + +table DivOptions { + fused_activation_function:ActivationFunctionType; +} + +table TopKV2Options { +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +table GatherOptions { + axis: int; + // Parameters for Gather version 5 or above. + batch_dims: int = 0; +} + +table TransposeOptions { +} + +table ExpOptions { +} + +table CosOptions { +} + +table ReducerOptions { + keep_dims: bool; +} + +table SqueezeOptions { + squeeze_dims:[int]; +} + +table SplitOptions { + num_splits: int; +} + +table SplitVOptions { + num_splits: int; +} + +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; + // If true, then the end tensor is an offset of the begin tensor. 
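+  // For example (illustrative): with offset = true, begin = [2] and
+  // end = [3], the effective end index is begin + end = 5, i.e. the slice
+  // covers indices [2, 5) along that dimension.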
+ offset: bool; +} + +table LogSoftmaxOptions { +} + +table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; +} + +table DequantizeOptions { +} + +table MaximumMinimumOptions { +} + +table TileOptions { +} + +table ArgMaxOptions { + output_type : TensorType; +} + +table ArgMinOptions { + output_type : TensorType; +} + +table GreaterOptions { +} + +table GreaterEqualOptions { +} + +table LessOptions { +} + +table LessEqualOptions { +} + +table NegOptions { +} + +table SelectOptions { +} + +table SliceOptions { +} + +table TransposeConvOptions { + // Parameters supported by version 1, 2, 3: + padding:Padding; + stride_w:int; + stride_h:int; + + // Parameters supported by version 4: + fused_activation_function:ActivationFunctionType = NONE; + + // Parameters for TransposeConv version 5 or above. + // If set, use this for bias and accumulator. + // When set, quantized_bias_type defines the dtype for both bias and accumulator. + quantized_bias_type: TensorType; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table RankOptions { +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table AbsOptions { +} + + +table HardSwishOptions { +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + +table UnpackOptions { + num:int; + axis:int; +} + +table FloorDivOptions { +} + +table SquareOptions { +} + +table ZerosLikeOptions { +} + +table FillOptions { +} + +table FloorModOptions { +} + +table RangeOptions { +} + +table LeakyReluOptions { + alpha:float; +} + +table SquaredDifferenceOptions { +} + +// LINT.IfChange +enum MirrorPadMode : byte { + // Doesn't include borders. + REFLECT = 0, + // Includes borders. + SYMMETRIC = 1, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +table MirrorPadOptions { + mode:MirrorPadMode; +} + +table UniqueOptions { + idx_out_type:TensorType = INT32; +} + +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} + +table MatrixDiagOptions { +} + +table QuantizeOptions { +} + +table MatrixSetDiagOptions { +} + +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table CallOnceOptions { + init_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + +table ScatterNdOptions { +} + +table SelectV2Options { +} + +table DensifyOptions { +} + +table SegmentSumOptions { +} + +table BatchMatMulOptions { + adj_x:bool; + adj_y:bool; + // Parameters for BatchMatMul version 4 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table CumsumOptions { + exclusive:bool; + reverse:bool; +} + +table BroadcastToOptions { +} + +table Rfft2dOptions { +} + +table HashtableOptions { + // The identity of hash tables. 
This identity will be used across different
+  // subgraphs in the same interpreter instance.
+  table_id:int;
+  key_dtype:TensorType;
+  value_dtype:TensorType;
+}
+
+table HashtableFindOptions {
+}
+
+table HashtableImportOptions {
+}
+
+table HashtableSizeOptions {
+}
+
+table VarHandleOptions {
+  container:string;
+  shared_name:string;
+}
+
+table ReadVariableOptions {
+}
+
+table AssignVariableOptions {
+}
+
+table RandomOptions {
+  seed: long;
+  seed2: long;
+}
+
+table BucketizeOptions {
+  boundaries: [float];  // The bucket boundaries.
+}
+
+table GeluOptions {
+  approximate: bool;
+}
+
+table DynamicUpdateSliceOptions {
+}
+
+table UnsortedSegmentProdOptions {
+}
+
+table UnsortedSegmentMaxOptions {
+}
+
+table UnsortedSegmentSumOptions {
+}
+
+table ATan2Options {
+}
+
+table UnsortedSegmentMinOptions {
+}
+
+table SignOptions {
+}
+
+table BitcastOptions {
+}
+
+table BitwiseXorOptions {
+}
+
+table RightShiftOptions {
+}
+
+table DilateOptions {
+}
+
+enum ReduceWindowFunction : int {
+  UNSUPPORTED,
+  ADD,
+  MUL,
+  MINIMUM,
+  MAXIMUM,
+  ALL,
+  ANY,
+}
+
+table ReduceWindowOptions (deprecated) {
+  reduce_function: ReduceWindowFunction;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+  // This field is for backward compatibility. It is used when the value of
+  // the extended builtin_code field is less than
+  // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
+  deprecated_builtin_code:byte;
+  custom_code:string;
+
+  // The version of the operator. The version needs to be bumped whenever new
+  // parameters are introduced into an op.
+  version:int = 1;
+
+  // This field was introduced to resolve the op builtin code shortage problem
+  // (the original BuiltinOperator enum field was represented as a byte).
+  // It is used when the value of the extended builtin_code field is greater
+  // than BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
+  builtin_code:BuiltinOperator;
+}
+
+enum CustomOptionsFormat : byte {
+  FLEXBUFFERS = 0,
+}
+
+table StableHLOCompositeOptions {
+  name:string;
+  decomposition_subgraph_index:int32;
+  composite_attributes:[ubyte];
+  composite_attributes_format:CustomOptionsFormat;
+  version:int32;
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operation are configured using builtin_options
+// or custom_options.
+table Operator {
+  // Index into the operator_codes array. Using an integer here avoids
+  // complicated map lookups.
+  opcode_index:uint;
+
+  // Optional inputs are indicated by -1.
+  inputs:[int];
+  outputs:[int];
+
+  builtin_options:BuiltinOptions;
+  custom_options:[ubyte];
+  custom_options_format:CustomOptionsFormat;
+
+  // A list of booleans indicating the input tensors which are being mutated
+  // by this operator (e.g. used by RNN and LSTM).
+  // For example, if the "inputs" array refers to 5 tensors and the second and
+  // fifth are mutable variables, then this list will contain
+  // [false, true, false, false, true].
+  //
+  // If the list is empty, no variable is mutated in this operator.
+  // The list either has the same length as `inputs`, or is empty.
+  mutating_variable_inputs:[bool];
+
+  // A list of indices to the subgraph's "tensors" that are internal to an Op.
+  // Internal tensors are those that do not flow in or out of the operation,
+  // but instead are part of internal computation. As such, the operation's
+  // implementation may manage its memory more efficiently. They are needed
+  // however (i.e. not just an implementation detail) since they are part of
+  // the computation, which may require relevant metadata such as quantization
+  // parameters.
+  intermediates:[int];
+
+  // When an op uses custom_options in a model that is larger than 2GB, the
+  // following attributes are used instead to find the buffer location, which
+  // is stored outside of the flatbuffer; the offset is calculated relative to
+  // the beginning of the file and is only valid if > 1.
+  large_custom_options_offset: ulong;
+  large_custom_options_size: ulong;
+
+  // A FlatBuffers union has a 128-element limit in Java, so a second union
+  // was added. If BuiltinOptions2 runs out, a third one can be added.
+  builtin_options_2 : BuiltinOptions2;
+}
+
+// The root type, defining a subgraph, which typically represents an entire
+// model.
+table SubGraph {
+  // A list of all tensors used in this subgraph.
+  tensors:[Tensor];
+
+  // Indices of the tensors that are inputs into this subgraph. Note this is
+  // the list of non-static tensors that feed into the subgraph for inference.
+  inputs:[int];
+
+  // Indices of the tensors that are outputs out of this subgraph. Note this
+  // is the list of output tensors that are considered the product of the
+  // subgraph's inference.
+  outputs:[int];
+
+  // All operators, in execution order.
+  operators:[Operator];
+
+  // Name of this subgraph (used for debugging).
+  name:string;
+}
+
+// Table of raw data buffers (used for constant tensors). Referenced by tensors
+// by index. The generous alignment accommodates mmap-friendly data structures.
+table Buffer {
+  data:[ubyte] (force_align: 16);
+
+  // In a model that is larger than 2GB, buffers instead use the following
+  // attributes to find the stored data, which is outside of the flatbuffer;
+  // the offset is calculated relative to the beginning of the file and is
+  // only valid if > 1.
+  offset: ulong;
+  size: ulong;
+}
+
+table Metadata {
+  // A human readable string to uniquely identify a Metadata.
+  name:string;
+  // An index to the buffers table.
+  buffer:uint;
+}
+
+// Map from an alias name of a tensor to its index in the graph.
+// This is used in SignatureDef.
+table TensorMap {
+  // Represents the alias to use for this tensor.
+  name:string;
+
+  // The actual tensor index in the primary graph that 'name' corresponds to.
+  tensor_index:uint;
+}
+
+// This corresponds to SignatureDef in TensorFlow SavedModel.
+// The SignatureDef will be part of the SavedModel provided for conversion.
+table SignatureDef {
+  // Named inputs for this signature.
+  inputs:[TensorMap];
+
+  // Named outputs for this signature.
+  outputs:[TensorMap];
+
+  // Key value which was in the TensorFlow SavedModel SignatureDef map.
+  signature_key:string;
+
+  // Model tag, deprecated.
+  deprecated_tag:string (deprecated);
+
+  // Index of the subgraph that corresponds to the exported method.
+  subgraph_index:uint;
+}
+
+table Model {
+  // Version of the schema.
+  version:uint;
+
+  // A list of all operator codes used in this model. This is
+  // kept in order because operators carry an index into this
+  // vector.
+  operator_codes:[OperatorCode];
+
+  // All the subgraphs of the model. The 0th is assumed to be the main
+  // model.
+ subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 as + // their buffer. + buffers:[Buffer]; + + // Metadata about the model. Indirects into the existings buffers list. + // Deprecated, prefer to use metadata field. + metadata_buffer:[int]; + + // Metadata about the model. + metadata:[Metadata]; + + // Optional SignatureDefs for the model. + signature_defs:[SignatureDef]; +} + +root_type Model; diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD index fce754995766d5..0566ea545a7b0b 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/BUILD +++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD @@ -30,12 +30,12 @@ cc_library( "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", "//tensorflow/compiler/mlir/lite:tensorflow_lite_d2s", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", "//tensorflow/lite:framework", "//tensorflow/lite/core/api", "//tensorflow/lite/core/c:private_c_api_types", - "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/tools/optimize:reduced_precision_support", "@com_google_absl//absl/strings", "@flatbuffers", @@ -54,10 +54,10 @@ tf_cc_test( ], deps = [ ":sparsify_model", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/lite/core:model_builder", "//tensorflow/lite/core/api:error_reporter", "//tensorflow/lite/core/c:private_c_api_types", - "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/tools/optimize:reduced_precision_support", "@com_google_googletest//:gtest_main", "@flatbuffers", diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h index 53deff6d990bb0..e0659063bc4ec6 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h @@ -16,9 +16,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_ #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/c_api_types.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace lite { diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc index 861a02be9caa6c..71fc1927a02217 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc @@ -26,10 +26,10 @@ limitations under the License. 
#include #include #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/model_builder.h" -#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 9976d6ff363c8f..e4001d4c08b695 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -64,9 +64,11 @@ cc_library( deps = [ ":stablehlo_util", "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], alwayslink = 1, ) @@ -545,6 +547,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", @@ -752,6 +755,27 @@ cc_library( alwayslink = True, ) +cc_library( + name = "optimize_layout", + srcs = [ + "transforms/optimize_layout.cc", + ], + hdrs = ["transforms/passes.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_ops", + ], + alwayslink = 1, +) + gentbl_cc_library( name = "composite_lowering_inc_gen", compatible_with = get_compatible_with_portable(), @@ -794,7 +818,6 @@ tf_cc_binary( "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", - "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD index c487600517f9b8..5add6c730cac5e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD @@ -1,5 +1,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,12 +19,14 @@ package_group( tf_cc_binary( name = "odml-converter", srcs = ["odml_converter_main.cc"], + compatible_with = get_compatible_with_portable(), visibility = [ "//tensorflow/compiler/mlir/lite/stablehlo/odml_converter:__subpackages__", "//third_party/odml/infra:__subpackages__", ], # Prototype phase. 
deps = [ - ":all_passes", + ":outline_composites", + ":shlo_simplify", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass", @@ -33,12 +36,84 @@ tf_cc_binary( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Support", + "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + ], +) + +cc_library( + name = "outline_composites", + srcs = [ + "transforms/outline_composites.cc", + ], + hdrs = ["passes.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BufferizationInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + ], + alwayslink = 1, +) + +cc_library( + name = "shlo_simplify", + srcs = [ + "transforms/shlo_simplify.cc", + ], + hdrs = ["passes.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":folders", + ":passes_inc_gen", + ":shlo_simplify_inc_gen", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_ops", + ], + alwayslink = 1, +) + +gentbl_cc_library( + name = "shlo_simplify_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "transforms/generated_shlo_simplify.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/shlo_simplify.td", + deps = ["@stablehlo//:stablehlo_ops_td_files"], +) + +cc_library( + name = "folders", + srcs = ["folders.cc"], + hdrs = ["folders.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@stablehlo//:stablehlo_ops", ], ) gentbl_cc_library( name = "passes_inc_gen", + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -52,13 +127,3 @@ gentbl_cc_library( td_file = "passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) - -cc_library( - name = "all_passes", - hdrs = ["passes.h"], - deps = [":passes_inc_gen"], -) - -exports_files([ - "run_lit.sh", -]) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc new file mode 100644 index 00000000000000..cb48050db47cb5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc @@ -0,0 +1,129 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include <optional>
+#include <vector>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
+
+namespace mlir::odml {
+
+namespace {
+
+// Helper class for parsing operands to a foldable operation.
+class FoldAdaptor {
+ public:
+  // Returns std::nullopt if the operation cannot be folded.
+  static std::optional<FoldAdaptor> Create(Operation* operation) {
+    auto foldable_opr = [](Value val) -> bool {
+      return !llvm::isa<BlockArgument>(val) &&
+             llvm::isa<stablehlo::ConstantOp>(val.getDefiningOp());
+    };
+    if (!llvm::all_of(operation->getOperands(), foldable_opr)) {
+      return std::nullopt;
+    }
+    return FoldAdaptor(operation);
+  }
+
+  // Gets a list of ElementsAttr behind each constant operand.
+  llvm::SmallVector<ElementsAttr> OperandData() {
+    llvm::SmallVector<ElementsAttr> res;
+    res.reserve(operation_->getNumOperands());
+    for (auto opr : operation_->getOperands()) {
+      auto op = llvm::dyn_cast<stablehlo::ConstantOp>(opr.getDefiningOp());
+      res.push_back(op.getValue());
+    }
+    return res;
+  }
+
+  // Gets a pointer to the operation to be folded.
+  Operation* Op() { return operation_; }
+
+ private:
+  explicit FoldAdaptor(Operation* operation) : operation_(operation) {}
+  Operation* const operation_;
+};
+
+// APSInt provides operators which APInt does not, so allow for converting
+// to APSInt for computation. Only APInts can be directly read from
+// element attributes.
+static const APFloat& AddSign(const APFloat& v) { return v; }
+static APSInt AddSign(const APInt& v) { return APSInt(v); }
+
+template <typename ValType>
+static LogicalResult FoldDivOpInternal(stablehlo::DivOp op,
+                                       PatternRewriter& rewriter) {
+  auto adaptor = FoldAdaptor::Create(op);
+  if (!adaptor.has_value()) {
+    return failure();
+  }
+  auto const_oprs = adaptor.value().OperandData();
+
+  const bool lhs_splat = const_oprs[0].isSplat();
+  const bool rhs_splat = const_oprs[1].isSplat();
+
+  auto lhs_vals = const_oprs[0].getValues<ValType>();
+  auto rhs_vals = const_oprs[1].getValues<ValType>();
+  const auto num_results = std::max(lhs_vals.size(), rhs_vals.size());
+  std::vector<ValType> res;
+  res.reserve(num_results);
+
+  auto lhs_start = lhs_vals.begin();
+  auto rhs_start = rhs_vals.begin();
+
+  for (int i = 0; i < num_results; ++i) {
+    auto lhs_val = lhs_splat ? *lhs_start : *(lhs_start++);
+    auto rhs_val = rhs_splat ? *rhs_start : *(rhs_start++);
+    auto signed_lhs_val = AddSign(lhs_val);
+    auto signed_rhs_val = AddSign(rhs_val);
+    if (signed_rhs_val.isZero()) {
+      return failure();
+    }
+    res.push_back(signed_lhs_val / signed_rhs_val);
+  }
+
+  auto res_attr = DenseElementsAttr::get(
+      const_oprs[0].getType().cast<ShapedType>(), res);
+  rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(adaptor.value().Op(),
+                                                     res_attr);
+  return success();
+}
+
+static LogicalResult FoldDivOp(stablehlo::DivOp op, PatternRewriter& rewriter) {
+  auto etype = op.getType().getElementType();
+  if (etype.isa<FloatType>()) {
+    return FoldDivOpInternal<APFloat>(op, rewriter);
+  }
+  if (etype.isa<IntegerType>()) {
+    return FoldDivOpInternal<APInt>(op, rewriter);
+  }
+  return failure();
+}
+}  // namespace
+
+void PopulateFolderPatterns(RewritePatternSet& patternSet) {
+  patternSet.add(FoldDivOp);
+}
+
+}  // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass_registration.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.h
similarity index 52%
rename from tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass_registration.cc
rename to tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.h
index 3359ba08fd15df..6f3d2d55b33252 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass_registration.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.h
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,12 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_FOLDERS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_FOLDERS_H_
 
-#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h"
+namespace mlir::odml {
 
-namespace tensorflow {
+// Populates the pattern set with all folding patterns. These patterns
+// are intended to have precedence over any other patterns added to the set.
+void PopulateFolderPatterns(RewritePatternSet &patternSet);
 
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10,
-                      MlirRoundtripPass);
+}  // namespace mlir::odml
 
-}  // namespace tensorflow
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_FOLDERS_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc
index ecd7396c2a4622..a510e640a7abd8 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -42,7 +43,7 @@ int main(int argc, char* argv[]) { mlir::DialectRegistry registry; registry.insert(); + mlir::TF::TensorFlowDialect, mlir::chlo::ChloDialect>(); return failed( mlir::MlirOptMain(argc, argv, "ODML Converter Driver\n", registry)); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h index b3589356f196a2..42e5d18e11f965 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h @@ -16,8 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_ +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + namespace mlir::odml { +std::unique_ptr> CreateOutlineCompositesPass(); + +std::unique_ptr> CreateSHLOSimplifyPass(); + #define GEN_PASS_REGISTRATION #include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h.inc" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.td b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.td index 800d7e0d2ff59b..45360d8749dc84 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.td @@ -15,3 +15,29 @@ limitations under the License. include "mlir/Pass/PassBase.td" +def OutlineCompositesPass: Pass<"outline-composites", "func::FuncOp"> { + let summary = "Outlines specific patterns into composites."; + let description = [{ + Outline specific patterns into composites. Specific patterns can be any + sub-DAG within a single `Block*`. The signature of the new composite + matches the inupt and output edges from a node in the sub-DAG to a node out + of it. The associated decomposition has the same semantic as the matched + ops, but may not have identical structure. + }]; + + let options = []; + let constructor = "CreateOutlineCompositesPass()"; + let dependentDialects = ["mlir::chlo::ChloDialect", "mlir::stablehlo::StablehloDialect", "mlir::func::FuncDialect"]; +} + +def SHLOSimplifyPass: Pass<"shlo-simplify", "ModuleOp"> { + let summary = "Apply internal canonicalizations and foldings."; + let description = [{ + Applies various internally defined patterns. 
+ }]; + + let options = [ + + ]; + let constructor = "CreateSHLOSimplifyPass()"; +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/BUILD index c990b20c8fb51c..c78441dcfc446a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/BUILD @@ -21,5 +21,6 @@ filegroup( data = [ "//tensorflow/compiler/mlir/lite/stablehlo/odml_converter:odml-converter", "@llvm-project//llvm:FileCheck", + "@llvm-project//mlir:run_lit.sh", ], ) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/outline_composites.mlir b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/outline_composites.mlir new file mode 100644 index 00000000000000..81726aaba37bcc --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/outline_composites.mlir @@ -0,0 +1,57 @@ +// RUN: odml-converter --outline-composites %s -split-input-file | FileCheck %s + +func.func @geluWithCustomCallErf(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> + %1 = stablehlo.constant dense<0.707106769> : tensor<2xf32> + %2 = stablehlo.constant dense<5.000000e-01> : tensor<2xf32> + %3 = stablehlo.multiply %arg0, %2 : tensor<2xf32> + %4 = stablehlo.multiply %arg0, %1 : tensor<2xf32> + %5 = stablehlo.custom_call @mhlo.erf(%4) {mhlo.attributes = {}, mhlo.version = 1 : i64} : (tensor<2xf32>) -> tensor<2xf32> + %6 = stablehlo.add %5, %0 : tensor<2xf32> + %7 = stablehlo.multiply %3, %6 : tensor<2xf32> + return %7 : tensor<2xf32> +} + +// CHECK: func.func private @gelu_decomp_0(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %cst = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK: %cst_0 = stablehlo.constant dense<5.000000e-01> : tensor<2xf32> +// CHECK: %cst_1 = stablehlo.constant dense<0.707106769> : tensor<2xf32> +// CHECK: %0 = stablehlo.multiply %arg0, %cst_1 : tensor<2xf32> +// CHECK: %1 = chlo.erf %0 : tensor<2xf32> -> tensor<2xf32> +// CHECK: %2 = stablehlo.add %1, %cst : tensor<2xf32> +// CHECK: %3 = stablehlo.multiply %arg0, %cst_0 : tensor<2xf32> +// CHECK: %4 = stablehlo.multiply %3, %2 : tensor<2xf32> +// CHECK: return %4 : tensor<2xf32> + +// CHECK-LABEL: geluWithCustomCallErf +// CHECK: %0 = stablehlo.composite "odml.internal.gelu" %arg0 {composite_attributes = {approx = false}, decomposition = @gelu_decomp_0} : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %0 + +// ----- + +func.func @geluWithCHLOErf(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> + %1 = stablehlo.constant dense<0.707106769> : tensor<2xf32> + %2 = stablehlo.constant dense<5.000000e-01> : tensor<2xf32> + %3 = stablehlo.multiply %arg0, %2 : tensor<2xf32> + %4 = stablehlo.multiply %arg0, %1 : tensor<2xf32> + %5 = chlo.erf %4 : tensor<2xf32> -> tensor<2xf32> + %6 = stablehlo.add %5, %0 : tensor<2xf32> + %7 = stablehlo.multiply %3, %6 : tensor<2xf32> + return %7 : tensor<2xf32> +} + +// CHECK: func.func private @gelu_decomp_0(%arg0: tensor<2xf32>) -> tensor<2xf32> +// CHECK: %cst = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK: %cst_0 = stablehlo.constant dense<5.000000e-01> : tensor<2xf32> +// CHECK: %cst_1 = stablehlo.constant dense<0.707106769> : tensor<2xf32> +// CHECK: %0 = stablehlo.multiply %arg0, %cst_1 : tensor<2xf32> +// CHECK: %1 = chlo.erf %0 : tensor<2xf32> -> 
tensor<2xf32> +// CHECK: %2 = stablehlo.add %1, %cst : tensor<2xf32> +// CHECK: %3 = stablehlo.multiply %arg0, %cst_0 : tensor<2xf32> +// CHECK: %4 = stablehlo.multiply %3, %2 : tensor<2xf32> +// CHECK: return %4 : tensor<2xf32> + +// CHECK-LABEL: geluWithCHLOErf +// CHECK: %0 = stablehlo.composite "odml.internal.gelu" %arg0 {composite_attributes = {approx = false}, decomposition = @gelu_decomp_0} : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %0 diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/shlo_simplify.mlir b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/shlo_simplify.mlir new file mode 100644 index 00000000000000..e06431a1852b3b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/shlo_simplify.mlir @@ -0,0 +1,96 @@ +// RUN: odml-converter --shlo-simplify %s -split-input-file | FileCheck %s + +func.func @foldDiv() -> tensor<2xf32> { + %0 = stablehlo.constant dense<[2.0, 3.0]> : tensor<2xf32> + %1 = stablehlo.constant dense<[4.0, 6.0]> : tensor<2xf32> + %2 = stablehlo.divide %0, %1 : tensor<2xf32> + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: foldDiv +// CHECK: stablehlo.constant dense<5.000000e-01> : tensor<2xf32> + +// ----- + +func.func @foldDivLHSSplat() -> tensor<2xf32> { + %0 = stablehlo.constant dense<2.0> : tensor<2xf32> + %1 = stablehlo.constant dense<[4.0, 6.0]> : tensor<2xf32> + %2 = stablehlo.divide %0, %1 : tensor<2xf32> + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: foldDivLHSSplat +// CHECK: stablehlo.constant dense<[5.000000e-01, 0.333333343]> : tensor<2xf32> + +// ----- + +func.func @foldDivRHSSplat() -> tensor<2xf32> { + %0 = stablehlo.constant dense<[4.0, 6.0]> : tensor<2xf32> + %1 = stablehlo.constant dense<2.0> : tensor<2xf32> + %2 = stablehlo.divide %0, %1 : tensor<2xf32> + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: foldDivRHSSplat +// CHECK: stablehlo.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> + +// ----- + +func.func @foldDivBothSplat() -> tensor<2xf32> { + %0 = stablehlo.constant dense<4.0> : tensor<2xf32> + %1 = stablehlo.constant dense<2.0> : tensor<2xf32> + %2 = stablehlo.divide %0, %1 : tensor<2xf32> + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: foldDivBothSplat +// CHECK: stablehlo.constant dense<2.000000e+00> : tensor<2xf32> + +// ----- + +func.func @foldDivF64() -> tensor<2xf64> { + %0 = stablehlo.constant dense<[2.0, 3.0]> : tensor<2xf64> + %1 = stablehlo.constant dense<[4.0, 6.0]> : tensor<2xf64> + %2 = stablehlo.divide %0, %1 : tensor<2xf64> + return %2 : tensor<2xf64> +} + +// CHECK-LABEL: foldDivF64 +// CHECK: stablehlo.constant dense<5.000000e-01> : tensor<2xf64> + +// ----- + +func.func @foldDivI32() -> tensor<2xi32> { + %0 = stablehlo.constant dense<[9, 3]> : tensor<2xi32> + %1 = stablehlo.constant dense<[4, 6]> : tensor<2xi32> + %2 = stablehlo.divide %0, %1 : tensor<2xi32> + return %2 : tensor<2xi32> +} + +// CHECK-LABEL: foldDivI32 +// CHECK: stablehlo.constant dense<[2, 0]> : tensor<2xi32> + +// ----- + +func.func @divideToMulReciprocalSplat(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = stablehlo.constant dense<2.0> : tensor<2xf32> + %2 = stablehlo.divide %arg0, %0 : tensor<2xf32> + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: divideToMulReciprocalSplat +// CHECK: stablehlo.constant dense<5.000000e-01> : tensor<2xf32> +// CHECK: stablehlo.multiply + +// ----- + +func.func @divideToMulReciprocal(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = stablehlo.constant dense<[2.0, 3.0]> : tensor<2xf32> + %2 = stablehlo.divide %arg0, %0 : 
tensor<2xf32> + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: divideToMulReciprocal +// CHECK: stablehlo.constant dense<[5.000000e-01, 0.333333343]> : tensor<2xf32> +// CHECK: stablehlo.multiply + diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/outline_composites.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/outline_composites.cc new file mode 100644 index 00000000000000..821ba4fa7e4d2f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/outline_composites.cc @@ -0,0 +1,252 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir { +namespace odml { +namespace { + +// TODO - b/330337238: Surface these to other files when needed. +constexpr llvm::StringLiteral kCompositeNamespace = "odml.internal"; +constexpr llvm::StringLiteral kGelu = "gelu"; + +std::string MakeCompositeName(llvm::StringRef op_name) { + return (kCompositeNamespace + "." + op_name).str(); +} + +#define GEN_PASS_DEF_OUTLINECOMPOSITESPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h.inc" + +constexpr float kOne = 1.0; +const float kOneOverRoot2 = kOne / std::sqrt(2); +constexpr float kHalf = kOne / 2.0; +constexpr float kTolerance = kOne / 1000.0; + +// Gets the operation that uses the sole result of given operation +// if there is only one. +Operation* GetUserIfOnlyOne(Operation* op) { + if (op->getNumResults() != 1) return nullptr; + auto result = op->getResult(0); + if (!result.hasOneUse()) return nullptr; + return (*result.getUses().begin()).getOwner(); +} + +// Gets operation providing value for the given operand of given operation +// if the given operation is the only user. 
+// Determines if the given op is semantically that of the gauss error function.
+bool MatchERF(Operation* op) {
+  if (auto custom_call = llvm::dyn_cast_or_null<stablehlo::CustomCallOp>(op)) {
+    return custom_call.getCallTargetName() == "mhlo.erf";
+  }
+  return llvm::isa<chlo::ErfOp>(op);
+}
+
+// Builds a reference implementation of non-approximate GELU.
+func::FuncOp BuildGELUDecomposition(RankedTensorType type,
+                                    PatternRewriter& rewriter,
+                                    Block* insertion_point) {
+  rewriter.setInsertionPointToStart(insertion_point);
+
+  auto ftype = FunctionType::get(rewriter.getContext(), {type}, {type});
+  auto name = rewriter.getStringAttr("gelu_decomp");
+  func::FuncOp new_func = rewriter.create<func::FuncOp>(
+      insertion_point->front().getLoc(), name, ftype);
+  new_func.setPrivate();
+  new_func.addEntryBlock();
+  rewriter.setInsertionPointToStart(&new_func.getBody().front());
+
+  auto one_val = DenseElementsAttr::get(type, kOne);
+  auto one_cst =
+      rewriter.create<stablehlo::ConstantOp>(rewriter.getUnknownLoc(), one_val);
+
+  auto half_val = DenseElementsAttr::get(type, kHalf);
+  auto half_cst =
+      rewriter.create<stablehlo::ConstantOp>(one_cst.getLoc(), half_val);
+
+  auto one_over_root2_val = DenseElementsAttr::get(type, kOneOverRoot2);
+  auto one_over_root2_cst = rewriter.create<stablehlo::ConstantOp>(
+      half_cst.getLoc(), one_over_root2_val);
+
+  auto mul_op = rewriter.create<stablehlo::MulOp>(one_over_root2_cst.getLoc(),
+                                                  new_func.getArguments()[0],
+                                                  one_over_root2_cst);
+  auto erf_op = rewriter.create<chlo::ErfOp>(mul_op.getLoc(), mul_op);
+  auto add_op =
+      rewriter.create<stablehlo::AddOp>(erf_op.getLoc(), erf_op, one_cst);
+  auto lhs_mul_op = rewriter.create<stablehlo::MulOp>(
+      half_cst.getLoc(), new_func.getArguments()[0], half_cst);
+  auto output_mul_op = rewriter.create<stablehlo::MulOp>(lhs_mul_op.getLoc(),
+                                                         lhs_mul_op, add_op);
+
+  rewriter.create<func::ReturnOp>(output_mul_op.getLoc(),
+                                  output_mul_op.getResult());
+  rewriter.clearInsertionPoint();
+  return new_func;
+}
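In equation form, the decomposition built above is the exact (erf-based) GELU, which is why the composite emitted below carries the attribute approx = false:

  \mathrm{GELU}(x) \;=\; \tfrac{1}{2}\,x\,\Bigl(1 + \operatorname{erf}\bigl(x/\sqrt{2}\bigr)\Bigr)

The two multiply branches in the diagram that follows correspond to the 0.5 * x factor and to the (1 + erf(x / sqrt(2))) term.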
+// Outlines non-approximate GELU into a stablehlo composite.
+//
+//             -> mul 1/sqrt(2) -> erf -> add 1 ->
+// in                                              mul
+//             ---------> mul 0.5 --------------->
+//
+// This pattern assumes all binary elementwise ops with one constant argument
+// have that constant argument as the second operand. It works by identifying
+// `erf` ops and validating the structure around them.
+class OutlineGELU : public RewritePattern {
+ public:
+  explicit OutlineGELU(MLIRContext* context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation* op,
+                                PatternRewriter& rewriter) const override {
+    if (!MatchERF(op)) return failure();
+    // `add 1`
+    auto* erf_user = GetUserIfOnlyOne(op);
+    if (!erf_user) return failure();
+
+    // `mul`
+    auto* erf_user_user = GetUserIfOnlyOne(erf_user);
+    if (!erf_user_user) return failure();
+
+    // `mul 1/sqrt(2)`
+    auto* erf_input = GetInputOpWithOneUse(op, 0);
+    if (!erf_input) return failure();
+
+    // `mul 0.5`
+    auto* erf_user_user_input = GetInputOpWithOneUse(erf_user_user, 0);
+    if (!erf_user_user_input) return failure();
+
+    // Check that `mul 0.5` and `mul 1/sqrt(2)` refer to the same input.
+    if (erf_user_user_input->getOperand(0) != erf_input->getOperand(0)) {
+      return failure();
+    }
+
+    // Check that the structural matches have the correct op type and values.
+    auto rhs_mul = llvm::dyn_cast_or_null<stablehlo::MulOp>(erf_input);
+    if (!rhs_mul) return failure();
+
+    auto lhs_mul =
+        llvm::dyn_cast_or_null<stablehlo::MulOp>(erf_user_user_input);
+    if (!lhs_mul) return failure();
+
+    auto output_mul = llvm::dyn_cast_or_null<stablehlo::MulOp>(erf_user_user);
+    if (!output_mul) return failure();
+
+    auto rhs_add = llvm::dyn_cast_or_null<stablehlo::AddOp>(erf_user);
+    if (!rhs_add) return failure();
+
+    if (!HasSplatArg(rhs_add, kOne, 1)) return failure();
+    if (!HasSplatArg(lhs_mul, kHalf, 1)) return failure();
+    if (!HasSplatArg(rhs_mul, kOneOverRoot2, 1)) return failure();
+
+    // Build a function to serve as the GELU decomposition in the
+    // shlo composite op.
+    auto root = op->getParentOfType<ModuleOp>();
+    auto func = BuildGELUDecomposition(
+        rhs_add.getType().cast<RankedTensorType>(), rewriter, root.getBody());
+
+    SymbolTable table(root);
+    (void)table.renameToUnique(func, {});
+
+    rewriter.setInsertionPointAfter(output_mul);
+    auto composite_attrs = rewriter.getDictionaryAttr(
+        {rewriter.getNamedAttr("approx", rewriter.getBoolAttr(false))});
+    auto composite_op = rewriter.create<stablehlo::CompositeOp>(
+        output_mul.getLoc(), func.getResultTypes()[0],
+        SmallVector<Value>{erf_input->getOperand(0)}, MakeCompositeName(kGelu),
+        composite_attrs, func.getSymName());
+    rewriter.replaceAllOpUsesWith(output_mul, composite_op);
+    // Note: these must be erased in reverse topological order to avoid
+    // failing in debug mode.
+    rewriter.eraseOp(output_mul);
+    rewriter.eraseOp(rhs_add);
+    rewriter.eraseOp(op);
+    rewriter.eraseOp(lhs_mul);
+    rewriter.eraseOp(rhs_mul);
+
+    return success();
+  }
+};
+
+class OutlineCompositesPass
+    : public impl::OutlineCompositesPassBase<OutlineCompositesPass> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OutlineCompositesPass)
+
+  void runOnOperation() override {
+    auto func = getOperation();
+    RewritePatternSet patterns(&getContext());
+    patterns.add<OutlineGELU>(&getContext());
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>> CreateOutlineCompositesPass() {
+  return std::make_unique<OutlineCompositesPass>();
+}
+
+}  // namespace odml
+}  // namespace mlir
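As a usage sketch (not part of this change; the helper name is hypothetical): the factory returns an OperationPass<func::FuncOp>, so a caller would schedule it as a nested pass on a module-level pipeline, e.g.:

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project

// Hypothetical wiring for the composite outlining pass.
void AddOutlineComposites(mlir::PassManager& pm) {
  // Nested because the pass operates per func.func, not on the module.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::odml::CreateOutlineCompositesPass());
}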
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.cc
new file mode 100644
index 00000000000000..668fe06515812e
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.cc
@@ -0,0 +1,60 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <utility>
+
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
+#include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo  // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.h"
+
+namespace mlir {
+namespace odml {
+namespace {
+
+#define GEN_PASS_DEF_SHLOSIMPLIFYPASS
+#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h.inc"
+#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/generated_shlo_simplify.inc"
+
+// Performs miscellaneous odml "cleanup" on the shlo dialect. This is a
+// functional stand-in for canonicalization and folding, which the shlo
+// implementation does not offer directly.
+class SHLOSimplifyPass : public impl::SHLOSimplifyPassBase<SHLOSimplifyPass> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SHLOSimplifyPass)
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    RewritePatternSet patterns(&getContext());
+    populateWithGenerated(patterns);
+    PopulateFolderPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> CreateSHLOSimplifyPass() {
+  return std::make_unique<SHLOSimplifyPass>();
+}
+
+}  // namespace odml
+}  // namespace mlir
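The TableGen pattern in the new shlo_simplify.td below, together with the folders populated above, rewrites division by a float constant into multiplication by its reciprocal, and the reciprocal itself is folded at compile time (cf. the divideToMulReciprocal tests earlier in this change):

  \frac{x}{c} \;\longrightarrow\; x \cdot \frac{1}{c},
  \qquad \text{e.g. } c = (2,\,3) \;\Rightarrow\; x \cdot (0.5,\; 0.33333334)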
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td
new file mode 100644
index 00000000000000..c8d19baeb11d0d
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td
@@ -0,0 +1,38 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+include "stablehlo/dialect/StablehloOps.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/BuiltinAttributes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
+
+def CloneF32ElementsAttrWithOnes
+  : NativeCodeCall<"DenseElementsAttr::get($0.getType().cast<ShapedType>(), (float)1.0)">;
+
+def NotConstant : Constraint<
+  CPred<"$0.isa<BlockArgument>() || !llvm::isa<stablehlo::ConstantOp>($0.getDefiningOp())">,
+  "Is not a constant.">;
+
+def : Pat<(StableHLO_DivOp $l,
+           (StableHLO_ConstantOp:$divisor FloatElementsAttr<32>:$cst)),
+          (StableHLO_MulOp $l,
+           (StableHLO_DivOp
+            (StableHLO_ConstantOp (CloneF32ElementsAttrWithOnes $cst)),
+            $divisor)),
+          [(NotConstant $l)]>;
+
+
+
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc
index f1d6b237ac2ef6..28afcd43a03218 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc
@@ -63,7 +63,6 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h"
 #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
 #include "xla/mlir/framework/transforms/passes.h"
-#include "xla/mlir_hlo/lhlo/transforms/passes.h"
 #include "xla/mlir_hlo/mhlo/IR/register.h"
 #include "xla/mlir_hlo/mhlo/transforms/passes.h"
 #include "tensorflow/core/platform/errors.h"
@@ -156,7 +155,7 @@ opt<std::string> exported_model_signatures(
 namespace mlir {
 namespace odml {
 
-tensorflow::StatusOr<OwningOpRef<ModuleOp>> ImportSavedModelOrMLIR(
+absl::StatusOr<OwningOpRef<ModuleOp>> ImportSavedModelOrMLIR(
     const std::string& input_path, MLIRContext* context,
     llvm::SourceMgr* source_mgr,
     std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle) {
@@ -215,7 +214,7 @@ tensorflow::Status ExportModule(mlir::ModuleOp module,
   output->os() << result;
   output->keep();
 
-  return ::tensorflow::OkStatus();
+  return absl::OkStatus();
 }
 
 tensorflow::Status ConvertTFToStableHLO(
@@ -261,7 +260,7 @@ tensorflow::Status ConvertTFToStableHLO(
     return tensorflow::errors::Aborted("Lowering to StableHLO failed.");
   }
 
-  return ::tensorflow::OkStatus();
+  return absl::OkStatus();
 }
 
 tensorflow::Status RunConverter(const PassPipelineCLParser& pass_pipeline) {
@@ -352,7 +351,6 @@ void initAllPasses() {
   mlir::registerAllPasses();
   mlir::registerTensorFlowPasses();
   mlir::mhlo::registerAllMhloPasses();
-  mlir::lmhlo::registerAllLmhloPasses();
 
 // These are in compiler/mlir/tf2xla and not part of the above MHLO passes.
mlir::mhlo::registerTfXlaPasses(); mlir::mhlo::registerLegalizeTFPass(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/compose-uniform-quantized-type.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/compose-uniform-quantized-type.mlir index b98d8af67ccd29..7e60dc85a487a6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/compose-uniform-quantized-type.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/compose-uniform-quantized-type.mlir @@ -27,7 +27,7 @@ module { %20 = call @uniform_dequantize_0(%19, %5, %6) : (tensor<1x3x3x4xi8>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xi8>) -> tensor<1x3x3x4xf32> return %20 : tensor<1x3x3x4xf32> } -// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<1> : tensor<3x3x4x4xi8>} : () -> tensor<3x3x4x4x!quant.uniform> +// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<1> : tensor<3x3x4x4xi8>}> : () -> tensor<3x3x4x4x!quant.uniform> // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x3x3x4xf32>) -> tensor<1x3x3x4x!quant.uniform> // CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[QUANT_ARG]], %[[FILTER]]) {{.*}} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x4x!quant.uniform>) -> tensor<1x3x3x4x!quant.uniform> // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x3x4xf32> @@ -87,7 +87,7 @@ module { %18 = call @uniform_dequantize_0(%17, %5, %6) : (tensor<1x3x3x4xi8>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xi8>) -> tensor<1x3x3x4xf32> return %18 : tensor<1x3x3x4xf32> } -// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<20> : tensor<3x3x4x4xi8>} : () -> tensor<3x3x4x4x!quant.uniform> +// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<20> : tensor<3x3x4x4xi8>}> : () -> tensor<3x3x4x4x!quant.uniform> // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x3x3x4xf32>) -> tensor<1x3x3x4x!quant.uniform> // CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[QUANT_ARG]], %[[FILTER]]) {{.*}} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x4x!quant.uniform>) -> tensor<1x3x3x4x!quant.uniform> // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x3x4xf32> @@ -182,7 +182,7 @@ module { return %17 : tensor<1x4x3xf32> } // Quantization dimension == 1 because it is the output feature dimension. -// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<5> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> +// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<5> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform> // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x4x2xf32>) -> tensor<1x4x2x!quant.uniform> // CHECK: %[[CONV:.*]] = stablehlo.dot_general %[[QUANT_ARG]], %[[FILTER]], contracting_dims = [2] x [0] : (tensor<1x4x2x!quant.uniform>, tensor<2x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x4x3x!quant.uniform>) -> tensor<1x4x3xf32> @@ -238,7 +238,7 @@ module { } // Quantization dimension == 1 because it is the output feature dimension. // Quantized filter values (from f32 constant) are cast to i8. 
-// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<5> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> +// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<5> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform> // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x4x2xf32>) -> tensor<1x4x2x!quant.uniform> // CHECK: %[[CONV:.*]] = stablehlo.dot_general %[[QUANT_ARG]], %[[FILTER]], contracting_dims = [2] x [0] : (tensor<1x4x2x!quant.uniform>, tensor<2x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x4x3x!quant.uniform>) -> tensor<1x4x3xf32> @@ -292,7 +292,7 @@ module { return %15 : tensor<1x3xf32> } // Quantization dimension == 1 because it is the output feature dimension. -// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<5> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> +// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<5> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform> // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> // CHECK: %[[CONV:.*]] = stablehlo.dot_general %[[QUANT_ARG]], %[[FILTER]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> @@ -431,8 +431,8 @@ module { %2 = stablehlo.constant dense<-128> : tensor<1x1x1xi8> // Input 1 zero point (z1). %3 = stablehlo.constant dense<-128> : tensor<1x1x1xi32> // Input 1 zero point (z1) (upcast & folded into i32). %4 = stablehlo.constant dense<4.000000e-01> : tensor<1x1x1xf32> // Input 2 inverse scale (1 / s2). - %5 = stablehlo.constant dense<-3> : tensor<1x1x1xi8> // Input 2 zero point (z2). - %6 = stablehlo.constant dense<-3> : tensor<1x1x1xi32> // Input 2 zero point (z2) (upcast & folded into i32). + %5 = stablehlo.constant dense<0> : tensor<1x1x1xi8> // Input 2 zero point (z2). + %6 = stablehlo.constant dense<0> : tensor<1x1x1xi32> // Input 2 zero point (z2) (upcast & folded into i32). %7 = stablehlo.constant dense<5.000000e-01> : tensor<1x1x1xf32> // Output inverse scale (1 / s3). %8 = stablehlo.constant dense<-5> : tensor<1x1x1xi8> // Output zero point (z3). %9 = stablehlo.constant dense<1.250000e+01> : tensor<1x1x1xf32> // Merged scale (s1 * s2). 
@@ -454,8 +454,8 @@ module { return %23 : tensor<8x16x4xf32> } // CHECK: %[[UQ_0:.*]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<8x16x16xf32>) -> tensor<8x16x16x!quant.uniform> -// CHECK: %[[UQ_1:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<8x16x4xf32>) -> tensor<8x16x4x!quant.uniform> -// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[UQ_0]], %[[UQ_1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<8x16x16x!quant.uniform>, tensor<8x16x4x!quant.uniform>) -> tensor<8x16x4x!quant.uniform> +// CHECK: %[[UQ_1:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<8x16x4xf32>) -> tensor<8x16x4x!quant.uniform> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[UQ_0]], %[[UQ_1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<8x16x16x!quant.uniform>, tensor<8x16x4x!quant.uniform>) -> tensor<8x16x4x!quant.uniform> // CHECK: %[[DQ_0:.*]] = stablehlo.uniform_dequantize %[[DOT_GENERAL]] : (tensor<8x16x4x!quant.uniform>) -> tensor<8x16x4xf32> // CHECK: return %[[DQ_0]] @@ -492,7 +492,7 @@ module { %1 = stablehlo.constant dense<2.000000e-01> : tensor<1x1x1xf32> // Input 1 inverse scale (1 / s1). %2 = stablehlo.constant dense<-128> : tensor<1x1x1xi8> // Input 1 zero point (z1). %3 = stablehlo.constant dense<4.000000e-01> : tensor<1x1x1xf32> // Input 2 inverse scale (1 / s2). - %4 = stablehlo.constant dense<-3> : tensor<1x1x1xi8> // Input 2 zero point (z2). + %4 = stablehlo.constant dense<0> : tensor<1x1x1xi8> // Input 2 zero point (z2). %5 = stablehlo.constant dense<5.000000e-01> : tensor<1x1x1xf32> // Output inverse scale (1 / s3). %6 = stablehlo.constant dense<-5> : tensor<1x1x1xi8> // Output zero point (z3). %7 = stablehlo.constant dense<1.250000e+01> : tensor<1x1x1xf32> // Merged scale (s1 * s2). 
@@ -516,8 +516,8 @@ module { return %23 : tensor<8x16x4xf32> } // CHECK: %[[UQ_0:.*]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<8x16x16xf32>) -> tensor<8x16x16x!quant.uniform> -// CHECK: %[[UQ_1:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<8x16x4xf32>) -> tensor<8x16x4x!quant.uniform> -// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[UQ_0]], %[[UQ_1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<8x16x16x!quant.uniform>, tensor<8x16x4x!quant.uniform>) -> tensor<8x16x4x!quant.uniform> +// CHECK: %[[UQ_1:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<8x16x4xf32>) -> tensor<8x16x4x!quant.uniform> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[UQ_0]], %[[UQ_1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<8x16x16x!quant.uniform>, tensor<8x16x4x!quant.uniform>) -> tensor<8x16x4x!quant.uniform> // CHECK: %[[DQ_0:.*]] = stablehlo.uniform_dequantize %[[DOT_GENERAL]] : (tensor<8x16x4x!quant.uniform>) -> tensor<8x16x4xf32> // CHECK: return %[[DQ_0]] diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir index c614ee10bf2b45..4121caa60a8e0d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir @@ -65,7 +65,7 @@ func.func private @XlaCallModule_aten.avg_pool2d.default.impl_0(%arg0: tensor<1x // CHECK: %[[VAL_2:.*]] = "tfl.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x3x6x6xf32>, tensor<4xi32>) -> tensor<1x6x6x3xf32> // CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<4x2xi32> // CHECK: %[[VAL_4:.*]] = "tfl.pad"(%[[VAL_2]], %[[VAL_3]]) : (tensor<1x6x6x3xf32>, tensor<4x2xi32>) -> tensor<1x6x6x3xf32> -// CHECK: %[[VAL_5:.*]] = "tfl.average_pool_2d"(%[[VAL_4]]) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x6x6x3xf32>) -> tensor<1x4x4x3xf32> +// CHECK: %[[VAL_5:.*]] = "tfl.average_pool_2d"(%[[VAL_4]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x6x6x3xf32>) -> tensor<1x4x4x3xf32> // CHECK: %[[VAL_6:.*]] = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> // CHECK: %[[VAL_7:.*]] = "tfl.transpose"(%[[VAL_5]], %[[VAL_6]]) : (tensor<1x4x4x3xf32>, tensor<4xi32>) -> tensor<1x3x4x4xf32> // CHECK: %[[VAL_8:.*]] = "tf.Identity"(%[[VAL_7]]) {device = ""} : (tensor<1x3x4x4xf32>) -> tensor<*xf32> @@ -102,7 +102,7 @@ func.func private @XlaCallModule_aten.avg_pool2d.default.impl_1(%arg0: tensor<1x // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x6x6xf32>) -> tensor<*xf32> { // CHECK: %[[VAL_1:.*]] = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> // CHECK: %[[VAL_2:.*]] = "tfl.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x3x6x6xf32>, tensor<4xi32>) -> tensor<1x6x6x3xf32> -// CHECK: %[[VAL_3:.*]] = "tfl.average_pool_2d"(%[[VAL_2]]) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x6x6x3xf32>) -> tensor<1x6x6x3xf32> +// CHECK: %[[VAL_3:.*]] = "tfl.average_pool_2d"(%[[VAL_2]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x6x6x3xf32>) -> tensor<1x6x6x3xf32> // CHECK: %[[VAL_4:.*]] = arith.constant 
dense<[0, 3, 1, 2]> : tensor<4xi32> // CHECK: %[[VAL_5:.*]] = "tfl.transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<1x6x6x3xf32>, tensor<4xi32>) -> tensor<1x3x6x6xf32> // CHECK: %[[VAL_6:.*]] = "tf.Identity"(%[[VAL_5]]) {device = ""} : (tensor<1x3x6x6xf32>) -> tensor<*xf32> @@ -172,8 +172,71 @@ func.func private @XlaCallModule_odml.upsample_bilinear2d.impl_21_0(%arg0: tenso // CHECK: %[[VAL_1:.*]] = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> // CHECK: %[[VAL_2:.*]] = "tfl.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x64x16x16xf32>, tensor<4xi32>) -> tensor<1x16x16x64xf32> // CHECK: %[[VAL_3:.*]] = arith.constant dense<32> : tensor<2xi32> -// CHECK: %[[VAL_4:.*]] = "tfl.resize_bilinear"(%[[VAL_2]], %[[VAL_3]]) {align_corners = false, half_pixel_centers = true} : (tensor<1x16x16x64xf32>, tensor<2xi32>) -> tensor<1x32x32x64xf32> +// CHECK: %[[VAL_4:.*]] = "tfl.resize_bilinear"(%[[VAL_2]], %[[VAL_3]]) <{align_corners = false, half_pixel_centers = true}> : (tensor<1x16x16x64xf32>, tensor<2xi32>) -> tensor<1x32x32x64xf32> // CHECK: %[[VAL_5:.*]] = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> // CHECK: %[[VAL_6:.*]] = "tfl.transpose"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x32x32x64xf32>, tensor<4xi32>) -> tensor<1x64x32x32xf32> // CHECK: return %[[VAL_6]] : tensor<1x64x32x32xf32> // CHECK: } + +func.func private @gelu_decomp(%arg0: tensor<2xf32>) -> tensor<2xf32> +func.func @gelu(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = mhlo.composite "odml.internal.gelu" %arg0 {composite_attributes = {approx = false}, decomposition = @gelu_decomp} : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: gelu +// CHECK: %0 = "tfl.gelu"(%arg0) <{approximate = false}> : (tensor<2xf32>) -> tensor<2xf32> + +// CHECK-LABEL func.func @jax_image_resize_nearest +func.func @jax_image_resize_nearest(%arg0: tensor<1x2x2x10xf32>) -> (tensor<1x4x4x10xf32>) { + %1 = mhlo.composite "odml.jax_resize_nearest_neighbor2d" %arg0 {composite_attributes = {output_size = dense<4> : tensor<2xi64>}, decomposition = @XlaCallModule_odml.jax_resize_nearest_neighbor2d.impl_0} : (tensor<1x2x2x10xf32>) -> tensor<1x4x4x10xf32> + return %1 : tensor<1x4x4x10xf32> +} +func.func private @XlaCallModule_odml.jax_resize_nearest_neighbor2d.impl_0(%arg0: tensor<1x2x2x10xf32>) -> tensor<1x4x4x10xf32> { + %0 = call @XlaCallModule__resize_0(%arg0) : (tensor<1x2x2x10xf32>) -> tensor<1x4x4x10xf32> + return %0 : tensor<1x4x4x10xf32> +} +func.func private @XlaCallModule__resize_0(%arg0: tensor<1x2x2x10xf32>) -> (tensor<1x4x4x10xf32>) { + %0 = mhlo.constant dense<2> : tensor + %1 = mhlo.constant dense<0> : tensor + %2 = mhlo.constant dense<4.000000e+00> : tensor + %3 = mhlo.constant dense<2.000000e+00> : tensor + %4 = mhlo.constant dense<5.000000e-01> : tensor + %5 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xf32> + %6 = "mhlo.broadcast_in_dim"(%4) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xf32> + %7 = mhlo.add %5, %6 : tensor<4xf32> + %8 = "mhlo.broadcast_in_dim"(%3) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xf32> + %9 = mhlo.multiply %7, %8 : tensor<4xf32> + %10 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xf32> + %11 = mhlo.divide %9, %10 : tensor<4xf32> + %12 = mhlo.floor %11 : tensor<4xf32> + %13 = mhlo.convert %12 : (tensor<4xf32>) -> tensor<4xi32> + %14 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xi32> + %15 = 
mhlo.compare LT, %13, %14, SIGNED : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %16 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xi32> + %17 = mhlo.add %13, %16 : tensor<4xi32> + %18 = mhlo.select %15, %17, %13 : tensor<4xi1>, tensor<4xi32> + %19 = "mhlo.broadcast_in_dim"(%18) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<4x1xi32> + %20 = "mhlo.gather"(%arg0, %19) <{dimension_numbers = #mhlo.gather, slice_sizes = dense<[1, 1, 2, 10]> : tensor<4xi64>}> : (tensor<1x2x2x10xf32>, tensor<4x1xi32>) -> tensor<1x4x2x10xf32> + %21 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xf32> + %22 = "mhlo.broadcast_in_dim"(%4) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xf32> + %23 = mhlo.add %21, %22 : tensor<4xf32> + %24 = "mhlo.broadcast_in_dim"(%3) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xf32> + %25 = mhlo.multiply %23, %24 : tensor<4xf32> + %26 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xf32> + %27 = mhlo.divide %25, %26 : tensor<4xf32> + %28 = mhlo.floor %27 : tensor<4xf32> + %29 = mhlo.convert %28 : (tensor<4xf32>) -> tensor<4xi32> + %30 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xi32> + %31 = mhlo.compare LT, %29, %30, SIGNED : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %32 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xi32> + %33 = mhlo.add %29, %32 : tensor<4xi32> + %34 = mhlo.select %31, %33, %29 : tensor<4xi1>, tensor<4xi32> + %35 = "mhlo.broadcast_in_dim"(%34) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<4x1xi32> + %36 = "mhlo.gather"(%20, %35) <{dimension_numbers = #mhlo.gather, slice_sizes = dense<[1, 4, 1, 10]> : tensor<4xi64>}> : (tensor<1x4x2x10xf32>, tensor<4x1xi32>) -> tensor<1x4x4x10xf32> + return %36 : tensor<1x4x4x10xf32> +} + +// CHECK: %cst = arith.constant dense<4> : tensor<2xi32> +// CHECK: %0 = "tfl.resize_nearest_neighbor"(%arg0, %cst) <{align_corners = false, half_pixel_centers = true}> : (tensor<1x2x2x10xf32>, tensor<2xi32>) -> tensor<1x4x4x10xf32> +// CHECK: return %0 : tensor<1x4x4x10xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir index 268247e815faa3..d64b50b72d533f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir @@ -5,14 +5,14 @@ module { func.func public @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<1x100x32x4xf32>, %arg3: tensor<1x500x4x4xf32>, %arg4: tensor<1x500x4x4xf32>, %arg5: tensor<1x1x100x500xf32>, %arg6: tensor) -> tensor<1x100x32x4xf32> { - // CHECK-ROUNDTRIP: %0 = "tfl.custom"(%arg2, %arg3, %arg4, %arg5, %arg6) {custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl} : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> + // CHECK-ROUNDTRIP: %0 = "tfl.custom"(%arg2, %arg3, %arg4, %arg5, %arg6) <{custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl}> : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> 
tensor<1x100x32x4xf32> %0 = func.call @test_sdpa(%arg2, %arg3, %arg4, %arg5, %arg6) : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> return %0: tensor<1x100x32x4xf32> } // CHECK-LABEL: func.func private @test_sdpa func.func private @test_sdpa(%arg0: tensor<1x100x32x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<1x500x4x4xf32>, %arg3: tensor<1x1x100x500xf32>, %arg4: tensor) -> tensor<1x100x32x4xf32> { - // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2, %arg3, %arg4) {custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl} : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2, %arg3, %arg4) <{custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl}> : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> %0 = stablehlo.composite "odml.scaled_dot_product_attention" %arg0, %arg1, %arg2, %arg3, %arg4 {decomposition = @odml.scaled_dot_product_attention.impl} : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> return %0 : tensor<1x100x32x4xf32> } @@ -23,8 +23,8 @@ module { // CHECK-LABEL: func.func private @test_multiple_kv_caches func.func private @test_multiple_kv_caches(%arg0: tensor<1x500x4x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<100xi64>, %arg3: tensor<1x100x4x4xf32>, %arg4: tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) { - // CHECK: %0:2 = "tfl.custom"(%arg2, %arg3, %arg4) {custom_code = "odml.update_kv_cache", custom_option = #tfl} : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) - // CHECK: %1:2 = "tfl.custom"(%arg2, %arg3, %arg4) {custom_code = "odml.update_kv_cache", custom_option = #tfl} : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + // CHECK: %0:2 = "tfl.custom"(%arg2, %arg3, %arg4) <{custom_code = "odml.update_kv_cache", custom_option = #tfl}> : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + // CHECK: %1:2 = "tfl.custom"(%arg2, %arg3, %arg4) <{custom_code = "odml.update_kv_cache", custom_option = #tfl}> : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) %0:2 = stablehlo.composite "odml.update_kv_cache" %arg0, %arg1, %arg2, %arg3, %arg4 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl_0} : (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) %1:2 = stablehlo.composite "odml.update_kv_cache" %0#0, %0#1, %arg2, %arg3, %arg4 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl_0} : (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) return %1#0, %1#1 : tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize_layout.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize_layout.mlir new file mode 100644 index 
00000000000000..25ae45f300b13e
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize_layout.mlir
@@ -0,0 +1,50 @@
+// RUN: odml-to-stablehlo-opt %s --transpose-commute-ops | FileCheck %s
+// CHECK-LABEL:   func.func @commute_transpose_pad(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x112x112x64xf32>,
+// CHECK-SAME:      %[[PAD_VAL:.*]]: tensor<f32>) -> tensor<1x64x114x114xf32> {
+// CHECK:           %[[PAD:.*]] = stablehlo.pad %[[INPUT]], %[[PAD_VAL]],
+// CHECK:             low = [0, 1, 1, 0], high = [0, 1, 1, 0], interior = [0, 0, 0, 0]
+// CHECK:             : (tensor<1x112x112x64xf32>, tensor<f32>) -> tensor<1x114x114x64xf32>
+// CHECK:           %[[TPOS:.*]] = stablehlo.transpose %[[PAD]], dims = [0, 3, 1, 2]
+// CHECK:             : (tensor<1x114x114x64xf32>) -> tensor<1x64x114x114xf32>
+// CHECK:           return %[[TPOS]] : tensor<1x64x114x114xf32>
+
+func.func @commute_transpose_pad(
+    %arg0: tensor<1x112x112x64xf32>, %padding_val: tensor<f32>)
+    -> tensor<1x64x114x114xf32> {
+  %tspos = stablehlo.transpose %arg0, dims = [0, 3, 1, 2]
+      : (tensor<1x112x112x64xf32>) -> tensor<1x64x112x112xf32>
+  %ret = stablehlo.pad %tspos, %padding_val,
+      low = [0, 0, 1, 1], high = [0, 0, 1, 1], interior = [0, 0, 0, 0]
+      : (tensor<1x64x112x112xf32>, tensor<f32>) -> tensor<1x64x114x114xf32>
+  return %ret : tensor<1x64x114x114xf32>
+}
+
+// -----
+// CHECK-LABEL:   func.func @commute_transpose_reduce_window(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x114x114x64xf32>,
+// CHECK-SAME:      %[[PAD_VAL:.*]]: tensor<f32>) -> tensor<1x64x56x56xf32> {
+// CHECK:           %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[INPUT]], %[[PAD_VAL]])
+// CHECK:             <{window_dimensions = array<i64: 1, 3, 3, 1>,
+// CHECK:             window_strides = array<i64: 1, 2, 2, 1>}> ({
+// CHECK:           ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>):
+// CHECK:             %[[MAX:.*]] = stablehlo.maximum %[[ARG0]], %[[ARG1]] : tensor<f32>
+// CHECK:             stablehlo.return %[[MAX]] : tensor<f32>
+// CHECK:           }) : (tensor<1x114x114x64xf32>, tensor<f32>) -> tensor<1x56x56x64xf32>
+// CHECK:           %[[TPOS:.*]] = stablehlo.transpose %[[REDUCE]], dims = [0, 3, 1, 2]
+// CHECK:             : (tensor<1x56x56x64xf32>) -> tensor<1x64x56x56xf32>
+// CHECK:           return %[[TPOS]] : tensor<1x64x56x56xf32>
+
+func.func @commute_transpose_reduce_window(
+    %input: tensor<1x114x114x64xf32>,
+    %cst: tensor<f32>) -> tensor<1x64x56x56xf32> {
+  %tpos = stablehlo.transpose %input, dims = [0, 3, 1, 2]
+      : (tensor<1x114x114x64xf32>) -> tensor<1x64x114x114xf32>
+  %ret = "stablehlo.reduce_window"(%tpos, %cst)
+      <{window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
+      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+        %max = stablehlo.maximum %arg0, %arg1 : tensor<f32>
+        stablehlo.return %max : tensor<f32>
+      }) : (tensor<1x64x114x114xf32>, tensor<f32>) -> tensor<1x64x56x56xf32>
+  return %ret : tensor<1x64x56x56xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir
index 9be635a44268f6..9b1c3f91ebb4e7 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir
@@ -15,7 +15,7 @@ func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
   func.return %0 : tensor<3x2xf32>
 
 // CHECK-LABEL: transpose_2d
-// CHECK-NEXT: %0 = "tfl.pseudo_const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK-NEXT: %0 = "tfl.pseudo_const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<2xi64>) -> tensor<2xi32>
 // CHECK-NEXT: %2 = "tfl.transpose"(%arg0, %1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
 // CHECK-NEXT: return
%2 : tensor<3x2xf32> @@ -26,7 +26,7 @@ func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { func.return %0 : tensor<3x2x1xf32> // CHECK-LABEL: transpose_3d -// CHECK-NEXT: %0 = "tfl.pseudo_const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK-NEXT: %0 = "tfl.pseudo_const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64> // CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32> // CHECK-NEXT: %2 = "tfl.transpose"(%arg0, %1) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32> // CHECK-NEXT: return %2 : tensor<3x2x1xf32> @@ -37,7 +37,7 @@ func.func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { func.return %0 : tensor<4x?xf32> // CHECK-LABEL: transpose_dynamic_2d -// CHECK-NEXT: %0 = "tfl.pseudo_const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK-NEXT: %0 = "tfl.pseudo_const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64> // CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<2xi64>) -> tensor<2xi32> // CHECK-NEXT: %2 = "tfl.transpose"(%arg0, %1) : (tensor, tensor<2xi32>) -> tensor<4x?xf32> // CHECK-NEXT: return %2 : tensor<4x?xf32> @@ -63,7 +63,7 @@ func.func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4 // CHECK: %[[TRANSPOSED_1:.*]] = "tfl.transpose" // CHECK-NEXT: %[[RESHAPED_0:.*]] = mhlo.reshape %[[TRANSPOSED_0]] // CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %[[TRANSPOSED_1]] -// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> +// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> // CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]] // CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<3x5x1x4xf32> } @@ -84,7 +84,7 @@ func.func @convert_dot_general_repeated(%arg0: tensor<1x1x1024xf32>, %arg1: tens // CHECK-LABEL: convert_dot_general_repeated // CHECK: %[[RESHAPED_0:.*]] = mhlo.reshape %arg0 // CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1 -// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : {{.*}} -> tensor<1x1024xf32> +// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x1024xf32> // CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]] // CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<1x1x1024xf32> } @@ -101,7 +101,7 @@ func.func @convert_dot_general_int8(%arg0: tensor<256xi8>, %arg1: tensor<256x8xi // CHECK-LABEL: convert_dot_general_int8 // CHECK: %[[RESHAPED_0:.*]] = mhlo.reshape %arg0 // CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1 -// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : {{.*}} -> tensor<1x8xi32> +// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x8xi32> // CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]] // CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<8xi32> } @@ -117,27 +117,27 @@ func.func 
@convert_dot_general_dynamic_rhs_out_dim(%arg0: tensor<4x4x256xf32>, % func.return %0 : tensor<4x4x?xf32> // CHECK-LABEL: convert_dot_general_dynamic_rhs_out_dim -// CHECK: %0 = "tfl.pseudo_const"() {value = dense<[0, 2, 1]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> // CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32> // CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> // CHECK-NEXT: %3 = mhlo.reshape %arg0 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32> // CHECK-NEXT: %4 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %5 = "tfl.pseudo_const"() {value = dense<[-1, 0, -1]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() {value = dense<[-1, -1, 0]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK-NEXT: %7 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor +// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-NEXT: %7 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor // CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%4, %5, %7) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> // CHECK-NEXT: %9 = "tfl.unsorted_segment_prod"(%4, %6, %7) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %10 = "tfl.pseudo_const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> -// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %9, %8) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %9, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK-NEXT: %12 = mhlo.dynamic_reshape %2, %11 : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> -// CHECK-NEXT: %13 = "tfl.batch_matmul"(%3, %12) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32> +// CHECK-NEXT: %13 = "tfl.batch_matmul"(%3, %12) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32> // CHECK-NEXT: %14 = "tfl.shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32> // CHECK-NEXT: %15 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %16 = "tfl.pseudo_const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK-NEXT: %17 = "tfl.gather"(%14, %16) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %18 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK-NEXT: %19 = "tfl.gather"(%15, %18) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32> -// CHECK-NEXT: %20 = "tfl.concatenation"(%17, %19) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK-NEXT: %17 = "tfl.gather"(%14, %16) <{axis = 0 : i32, 
batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK-NEXT: %19 = "tfl.gather"(%15, %18) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32> +// CHECK-NEXT: %20 = "tfl.concatenation"(%17, %19) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK-NEXT: %21 = mhlo.dynamic_reshape %13, %20 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> // CHECK-NEXT: return %21 : tensor<4x4x?xf32> } @@ -153,37 +153,37 @@ func.func @convert_dot_general_dynamic_batch_dim(%arg0: tensor<2x?x2x3xf32>, %ar func.return %0 : tensor<2x?x2x4xf32> // CHECK-LABEL: convert_dot_general_dynamic_batch_dim -// CHECK: %0 = "tfl.pseudo_const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> // CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> // CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x?x4x3xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> // CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %4 = "tfl.pseudo_const"() {value = dense<[-1, -1, 0, -1]> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK-NEXT: %5 = "tfl.pseudo_const"() {value = dense<[-1, -1, -1, 0]> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor +// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor // CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> // CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %9 = "tfl.pseudo_const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK-NEXT: %10 = "tfl.gather"(%3, %9) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %7, %8) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK-NEXT: %10 = "tfl.gather"(%3, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> // CHECK-NEXT: %12 = mhlo.dynamic_reshape %arg0, %11 : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32> // CHECK-NEXT: %13 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %14 = "tfl.pseudo_const"() {value = dense<[-1, -1, 0, -1]> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK-NEXT: %15 = "tfl.pseudo_const"() {value = dense<[-1, -1, -1, 0]> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK-NEXT: %16 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor +// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, 
-1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor // CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%13, %14, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> // CHECK-NEXT: %18 = "tfl.unsorted_segment_prod"(%13, %15, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %19 = "tfl.pseudo_const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK-NEXT: %20 = "tfl.gather"(%13, %19) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %21 = "tfl.concatenation"(%20, %18, %17) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-NEXT: %19 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK-NEXT: %20 = "tfl.gather"(%13, %19) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK-NEXT: %21 = "tfl.concatenation"(%20, %18, %17) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> // CHECK-NEXT: %22 = mhlo.dynamic_reshape %2, %21 : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> -// CHECK-NEXT: %23 = "tfl.batch_matmul"(%12, %22) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32> +// CHECK-NEXT: %23 = "tfl.batch_matmul"(%12, %22) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32> // CHECK-NEXT: %24 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> // CHECK-NEXT: %25 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %26 = "tfl.pseudo_const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK-NEXT: %27 = "tfl.gather"(%24, %26) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> -// CHECK-NEXT: %28 = "tfl.pseudo_const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64> -// CHECK-NEXT: %29 = "tfl.gather"(%25, %28) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32> -// CHECK-NEXT: %30 = "tfl.concatenation"(%27, %29) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK-NEXT: %27 = "tfl.gather"(%24, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> +// CHECK-NEXT: %28 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK-NEXT: %29 = "tfl.gather"(%25, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32> +// CHECK-NEXT: %30 = "tfl.concatenation"(%27, %29) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32> // CHECK-NEXT: %31 = mhlo.dynamic_reshape %23, %30 : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32> // CHECK-NEXT: return %31 : tensor<2x?x2x4xf32> } @@ -200,35 +200,35 @@ func.func @convert_dot_general_dynamic_lhs_rhs_out_dims(%arg0: tensor<2x2x?x3xf3 func.return %0 : 
tensor<2x2x?x4x?xf32> // CHECK-LABEL: convert_dot_general_dynamic_lhs_rhs_out_dims -// CHECK: %0 = "tfl.pseudo_const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> // CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> // CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x4x?x3xf32>, tensor<4xi32>) -> tensor<2x3x4x?xf32> // CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %4 = "tfl.pseudo_const"() {value = dense<[-1, 0, 0, -1]> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK-NEXT: %5 = "tfl.pseudo_const"() {value = dense<[-1, -1, -1, 0]> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor +// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor // CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> // CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %9 = "tfl.pseudo_const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> -// CHECK-NEXT: %10 = "tfl.concatenation"(%9, %7, %8) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-NEXT: %10 = "tfl.concatenation"(%9, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK-NEXT: %11 = mhlo.dynamic_reshape %arg0, %10 : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32> // CHECK-NEXT: %12 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %13 = "tfl.pseudo_const"() {value = dense<[-1, 0, 0, -1]> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK-NEXT: %14 = "tfl.pseudo_const"() {value = dense<[-1, -1, -1, 0]> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK-NEXT: %15 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor +// CHECK-NEXT: %13 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor // CHECK-NEXT: %16 = "tfl.unsorted_segment_prod"(%12, %13, %15) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> // CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%12, %14, %15) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %18 = "tfl.pseudo_const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> -// CHECK-NEXT: %19 = "tfl.concatenation"(%18, %17, %16) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-NEXT: %19 = "tfl.concatenation"(%18, %17, %16) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, 
tensor<1xi32>) -> tensor<3xi32> // CHECK-NEXT: %20 = mhlo.dynamic_reshape %2, %19 : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32> -// CHECK-NEXT: %21 = "tfl.batch_matmul"(%11, %20) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32> +// CHECK-NEXT: %21 = "tfl.batch_matmul"(%11, %20) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32> // CHECK-NEXT: %22 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> // CHECK-NEXT: %23 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %24 = "tfl.pseudo_const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64> -// CHECK-NEXT: %25 = "tfl.gather"(%22, %24) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> -// CHECK-NEXT: %26 = "tfl.pseudo_const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK-NEXT: %27 = "tfl.gather"(%23, %26) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %28 = "tfl.concatenation"(%25, %27) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32> +// CHECK-NEXT: %24 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK-NEXT: %25 = "tfl.gather"(%22, %24) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> +// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK-NEXT: %27 = "tfl.gather"(%23, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK-NEXT: %28 = "tfl.concatenation"(%25, %27) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32> // CHECK-NEXT: %29 = mhlo.dynamic_reshape %21, %28 : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32> // CHECK-NEXT: return %29 : tensor<2x2x?x4x?xf32> @@ -246,24 +246,24 @@ func.return %0 : tensor<4x4x256xf32> // CHECK-LABEL: convert_dot_general_dynamic_contracting_dim // CHECK: %0 = "tfl.shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32> -// CHECK-NEXT: %1 = "tfl.pseudo_const"() {value = dense<[-1, 0, -1]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK-NEXT: %2 = "tfl.pseudo_const"() {value = dense<[-1, -1, 0]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK-NEXT: %3 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor +// CHECK-NEXT: %1 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-NEXT: %2 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-NEXT: %3 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor // CHECK-NEXT: %4 = "tfl.unsorted_segment_prod"(%0, %1, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> // CHECK-NEXT: %5 = "tfl.unsorted_segment_prod"(%0, %2, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> -// CHECK-NEXT: %7 = "tfl.concatenation"(%6, %4, %5) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-NEXT: %7 = 
"tfl.concatenation"(%6, %4, %5) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK-NEXT: %8 = mhlo.dynamic_reshape %arg0, %7 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> // CHECK-NEXT: %9 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %10 = "tfl.pseudo_const"() {value = dense<[-1, -1, 0]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK-NEXT: %11 = "tfl.pseudo_const"() {value = dense<[-1, 0, -1]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK-NEXT: %12 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor +// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-NEXT: %11 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-NEXT: %12 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor // CHECK-NEXT: %13 = "tfl.unsorted_segment_prod"(%9, %10, %12) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> // CHECK-NEXT: %14 = "tfl.unsorted_segment_prod"(%9, %11, %12) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %15 = "tfl.pseudo_const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> -// CHECK-NEXT: %16 = "tfl.concatenation"(%15, %14, %13) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-NEXT: %16 = "tfl.concatenation"(%15, %14, %13) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK-NEXT: %17 = mhlo.dynamic_reshape %arg1, %16 : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32> -// CHECK-NEXT: %18 = "tfl.batch_matmul"(%8, %17) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32> +// CHECK-NEXT: %18 = "tfl.batch_matmul"(%8, %17) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32> // CHECK-NEXT: %19 = mhlo.reshape %18 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32> // CHECK-NEXT: return %19 : tensor<4x4x256xf32> } @@ -294,7 +294,7 @@ func.func @convert_argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, ten // CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> // CHECK: %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> // CHECK: %cst = arith.constant dense<2> : tensor<1xi32> - // CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> + // CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> // CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> // CHECK: return %4, %5 : tensor<4x32xf32>, tensor<4x32xi32> } @@ -323,7 +323,7 @@ func.func @convert_argmax_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32> // CHECK-DAG: %1 = mhlo.constant dense<0> : tensor // CHECK: %2 = mhlo.constant dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32> // CHECK: %cst = arith.constant dense<2> : tensor<1xi32> - // CHECK: %3 = "tfl.reduce_max"(%arg0, 
%cst) {keep_dims = false} : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xf32> + // CHECK: %3 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xf32> // CHECK: %4 = "tfl.arg_max"(%arg0, %cst) : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xi32> // CHECK: return %3, %4 : tensor<2x2xf32>, tensor<2x2xi32> } @@ -352,7 +352,7 @@ func.func @convert_argmax_constant_non_z_axis(%arg0: tensor<4x4xf32>) -> (tensor // CHECK-DAG: %1 = mhlo.constant dense<0> : tensor // CHECK: %2 = mhlo.constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]> : tensor<4x4xi32> // CHECK: %cst = arith.constant dense<0> : tensor<1xi32> - // CHECK: %3 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<4xf32> + // CHECK: %3 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<4xf32> // CHECK: %4 = "tfl.arg_max"(%arg0, %cst) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<4xi32> // CHECK: return %3, %4 : tensor<4xf32>, tensor<4xi32> } @@ -379,7 +379,7 @@ func.func @convert_argmax_bool(%arg0: tensor<2xi1>) -> tensor { // CHECK-DAG: %1 = mhlo.constant dense : tensor // CHECK: %2 = mhlo.constant dense<0> : tensor // CHECK: %cst = arith.constant dense<0> : tensor<1xi32> - // CHECK: %3 = "tfl.reduce_any"(%arg0, %cst) {keep_dims = false} : (tensor<2xi1>, tensor<1xi32>) -> tensor + // CHECK: %3 = "tfl.reduce_any"(%arg0, %cst) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor // CHECK: %4 = "tfl.arg_max"(%arg0, %cst) : (tensor<2xi1>, tensor<1xi32>) -> tensor // CHECK: return %4 : tensor } @@ -410,7 +410,7 @@ func.func @convert_argmin(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, ten // CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> // CHECK: %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> // CHECK: %cst = arith.constant dense<2> : tensor<1xi32> - // CHECK: %4 = "tfl.reduce_min"(%arg0, %cst) {keep_dims = false} : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> + // CHECK: %4 = "tfl.reduce_min"(%arg0, %cst) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> // CHECK: %5 = "tfl.arg_min"(%arg0, %cst) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> // CHECK: return %4, %5 : tensor<4x32xf32>, tensor<4x32xi32> } @@ -440,7 +440,7 @@ func.func @convert_argmin_i16(%arg0: tensor<2xi16>) -> (tensor, tensor // CHECK-DAG: %2 = mhlo.constant dense<32767> : tensor // CHECK: %3 = mhlo.constant dense<0> : tensor // CHECK: %cst = arith.constant dense<0> : tensor<1xi32> - // CHECK: %4 = "tfl.reduce_min"(%arg0, %cst) {keep_dims = false} : (tensor<2xi16>, tensor<1xi32>) -> tensor + // CHECK: %4 = "tfl.reduce_min"(%arg0, %cst) <{keep_dims = false}> : (tensor<2xi16>, tensor<1xi32>) -> tensor // CHECK: %5 = "tfl.arg_min"(%arg0, %cst) : (tensor<2xi16>, tensor<1xi32>) -> tensor // CHECK: return %4, %5 : tensor, tensor } @@ -470,7 +470,7 @@ func.func @convert_argmin_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32> // CHECK-DAG: %1 = mhlo.constant dense<0> : tensor // CHECK: %2 = mhlo.constant dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32> // CHECK: %cst = arith.constant dense<2> : tensor<1xi32> - // CHECK: %3 = "tfl.reduce_min"(%arg0, %cst) {keep_dims = false} : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xf32> + // CHECK: %3 = 
"tfl.reduce_min"(%arg0, %cst) <{keep_dims = false}> : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xf32> // CHECK: %4 = "tfl.arg_min"(%arg0, %cst) : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xi32> // CHECK: return %3, %4 : tensor<2x2xf32>, tensor<2x2xi32> } @@ -497,7 +497,7 @@ func.func @convert_argmin_bool(%arg0: tensor<2xi1>) -> tensor { // CHECK-DAG: %1 = mhlo.constant dense : tensor // CHECK: %2 = mhlo.constant dense<0> : tensor // CHECK: %cst = arith.constant dense<0> : tensor<1xi32> - // CHECK: %3 = "tfl.reduce_all"(%arg0, %cst) {keep_dims = false} : (tensor<2xi1>, tensor<1xi32>) -> tensor + // CHECK: %3 = "tfl.reduce_all"(%arg0, %cst) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor // CHECK: %4 = "tfl.arg_min"(%arg0, %cst) : (tensor<2xi1>, tensor<1xi32>) -> tensor // CHECK: return %4 : tensor } @@ -528,7 +528,7 @@ func.func @convert_argmax_with_reshaped_iota(%arg0: tensor<1x32x1xf32>) -> (tens // CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32> // CHECK: %3 = mhlo.reshape %2 : (tensor<32xi32>) -> tensor<1x32x1xi32> // CHECK: %cst = arith.constant dense<1> : tensor<1xi32> - // CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32> + // CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32> // CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32> // CHECK: return %4, %5 : tensor<1x1xf32>, tensor<1x1xi32> } @@ -556,7 +556,7 @@ func.func @convert_pytorch_argmax(%arg0: tensor<1x9xi32>) -> tensor<1xi32> { // CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<9xi32> // CHECK: %3 = mhlo.reshape %2 : (tensor<9xi32>) -> tensor<1x9xi32> // CHECK: %cst = arith.constant dense<1> : tensor<1xi32> - // CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: return %5 : tensor<1xi32> } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo_custom_call.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo_custom_call.mlir index f0cba06bd984d8..8b1b24e9888508 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo_custom_call.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo_custom_call.mlir @@ -8,10 +8,10 @@ func.func @mhlo_custom_call_test__legalize_string_backend_config(%arg0: tensor<1 } : (tensor<1x4xf32>) -> (tensor<1x8xf32>) func.return %0 : tensor<1x8xf32> - // CHECK: %0 = "tfl.custom"(%arg0) { + // CHECK: %0 = "tfl.custom"(%arg0) <{ // CHECK-SAME: custom_code = "custom_call.my_custom_op", // CHECK-SAME: custom_option = #tfl - // CHECK-SAME: } : (tensor<1x4xf32>) -> tensor<1x8xf32> + // CHECK-SAME: }> : (tensor<1x4xf32>) -> tensor<1x8xf32> } // CHECK-LABEL: mhlo_custom_call_test__dont_legalize_dict_backend_config @@ -35,10 +35,10 @@ func.func @mhlo_custom_call_test__api_version_4(%arg0: tensor<1x4xf32>) -> tenso } : (tensor<1x4xf32>) -> (tensor<1x8xf32>) func.return %0 : tensor<1x8xf32> - // CHECK: %0 = "tfl.custom"(%arg0) { + // CHECK: %0 = "tfl.custom"(%arg0) <{ // CHECK-SAME: custom_code = "custom_call.my_custom_op", // CHECK-SAME: custom_option = #tfl - // CHECK-SAME: } : (tensor<1x4xf32>) -> 
tensor<1x8xf32> + // CHECK-SAME: }> : (tensor<1x4xf32>) -> tensor<1x8xf32> } // CHECK-LABEL: mhlo_custom_call_does_not_legalize_tf_function diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index 7107f7dcb08a45..64b14b85fc7c71 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -13,7 +13,7 @@ func.func @uniform_quantize_op(%arg: tensor<2x2xf32>) -> tensor<2x2x!quant.unifo return %0 : tensor<2x2x!quant.uniform> } // CHECK-LABEL: uniform_quantize_op -// CHECK: %[[QUANT:.+]] = "tfl.quantize"({{.*}}) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> +// CHECK: %[[QUANT:.+]] = "tfl.quantize"({{.*}}) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> // CHECK: return %[[QUANT]] // ----- @@ -120,9 +120,9 @@ func.func @convolution_upstream_same_padding_srq(%arg0: tensor<1x3x3x4x!quant.un // to (2, 3, 3, 4). // CHECK-LABEL: convolution_upstream_same_padding_srq // CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>}> : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x3x3x2x!quant.uniform> // ----- @@ -149,10 +149,10 @@ func.func @convolution_upstream_srq_valid_padding(%arg0: tensor<1x3x3x4x!quant.u } // CHECK-LABEL: convolution_upstream_srq_valid_padding // CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> 
tensor<2x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>}> : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> // CHECK-NOT: tfl.pad -// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x1x1x2x!quant.uniform> // ----- @@ -168,9 +168,9 @@ func.func @convolution_upstream_srq_valid_padding(%arg0: tensor<1x3x3x4x!quant.u } // CHECK-LABEL: convolution_upstream_srq_valid_padding // CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>}> : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x1x1x2x!quant.uniform> // ----- @@ -186,10 +186,10 @@ func.func @convolution_upstream_srq_strides(%arg0: tensor<1x3x3x4x!quant.uniform } // CHECK-LABEL: convolution_upstream_srq_strides // CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, 
value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>}> : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> // Tests that the stride_w is set to 2. -// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32}> : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x3x2x2x!quant.uniform> // ----- @@ -210,8 +210,8 @@ func.func @dot_general_upstream_srq_asym_input(%arg0: tensor<1x2x3x4x!quant.unif } // CHECK-LABEL: dot_general_upstream_srq_asym_input // CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> -// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>}> : () -> tensor<1x2x4x5x!quant.uniform> +// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) <{adj_x = false, adj_y = false}> : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // ----- @@ -233,7 +233,7 @@ func.func @dot_general_upstream_srq_sym_input(%arg0: tensor<1x2x3x4x!quant.unifo // CHECK-LABEL: dot_general_upstream_srq_sym_input // CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> // CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() -// CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} +// CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) <{adj_x = false, adj_y = false}> // ----- @@ -252,7 +252,7 @@ func.func @dot_general_upstream_srq_activation_rhs(%arg0: tensor<1x2x3x4x!quant. 
return %0 : tensor<1x2x3x5x!quant.uniform> } // CHECK-LABEL: dot_general_upstream_srq_activation_rhs -// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK: "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // ----- @@ -274,8 +274,8 @@ func.func @dot_general_upstream_srq_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform> } // CHECK-SAME: %[[ARG:.+]]: tensor<1x2x4x3x!quant.uniform> -// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> -// CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = true, adj_y = false} +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>}> : () -> tensor<1x2x4x5x!quant.uniform> +// CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) <{adj_x = true, adj_y = false}> // ----- @@ -297,8 +297,8 @@ func.func @dot_general_upstream_srq_adj_y(%arg0: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x5x4x!quant.uniform>, value = dense<1> : tensor<1x2x5x4xi8>} : () -> tensor<1x2x5x4x!quant.uniform> -// CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = true} +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x5x4x!quant.uniform>, value = dense<1> : tensor<1x2x5x4xi8>}> : () -> tensor<1x2x5x4x!quant.uniform> +// CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) <{adj_x = false, adj_y = true}> // ----- @@ -393,13 +393,13 @@ func.func @dot_general_upstream_srq_float_operands(%arg0: tensor<1x2x3x4xf32>, % // CHECK-LABEL: dot_general_upstream_srq_asym_weight func.func @dot_general_upstream_srq_asym_weight(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { - %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> - %1 = "stablehlo.dot_general"(%arg0, %0) {dot_dimension_numbers = #stablehlo.dot, precision_config = [#stablehlo, #stablehlo]} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> + %1 = "stablehlo.dot_general"(%arg0, %0) {dot_dimension_numbers = #stablehlo.dot, precision_config = [#stablehlo, #stablehlo]} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } // CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> -// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>}> : () -> tensor<1x2x4x5x!quant.uniform> +// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) <{adj_x = false, adj_y = false}> : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> 
tensor<1x2x3x5x!quant.uniform> // ----- @@ -414,10 +414,10 @@ func.func @dot_general_upstream_srq_per_axis_quantized_filter(%arg0: tensor<1x3x } // CHECK-SAME: %[[ARG_0:.+]]: tensor<1x3x!quant.uniform> // Weight tensor is transposed, as tfl.fully_connected accepts a [o, i] matrix. -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<1> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<1> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> // Bias tensor's scale is input scale * filter scale. -// CHECK: %[[FC:.+]] = "tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x3x!quant.uniform>, tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[FC:.+]] = "tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x3x!quant.uniform>, tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x2x!quant.uniform> // CHECK-NEXT: return %[[FC]] : tensor<1x2x!quant.uniform> // ----- @@ -427,8 +427,8 @@ func.func @dot_general_upstream_srq_per_axis_quantized_filter(%arg0: tensor<1x3x // CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter_with_batch_dim func.func @dot_general_upstream_srq_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x1x3x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> { - %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> - %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x1x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> + %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> + %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x1x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> return %1 : tensor<1x1x2x!quant.uniform> } // Nothing changes. 
@@ -443,8 +443,8 @@ func.func @dot_general_upstream_srq_per_axis_quantized_filter_with_batch_dim(%ar // CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter_multibatch func.func @dot_general_upstream_srq_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> { - %0 = stablehlo.constant() {value = dense<1> : tensor<3x3x2xi8>} : () -> tensor<3x3x2x!quant.uniform> - %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x1x3x!quant.uniform>, tensor<3x3x2x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> + %0 = stablehlo.constant() {value = dense<1> : tensor<3x3x2xi8>} : () -> tensor<3x3x2x!quant.uniform> + %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x1x3x!quant.uniform>, tensor<3x3x2x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> return %1 : tensor<3x1x2x!quant.uniform> } // Nothing changes. @@ -459,8 +459,8 @@ func.func @dot_general_upstream_srq_per_axis_quantized_filter_multibatch(%arg0: // CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter_with_multiple_contracting_dims func.func @dot_general_upstream_srq_per_axis_quantized_filter_with_multiple_contracting_dims(%arg0: tensor<1x2x3x!quant.uniform>) -> tensor<1x1x!quant.uniform> { - %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> - %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1, 2] x [2, 1] : (tensor<1x2x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x!quant.uniform> + %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1, 2] x [2, 1] : (tensor<1x2x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x!quant.uniform> return %1 : tensor<1x1x!quant.uniform> } // Nothing changes. 
@@ -486,17 +486,17 @@ func.func @dot_general_upstream_srq_per_axis_quantized_filter_with_multiple_cont // * dot_general_with_relu6_fn func.func @dot_general_srq(%arg0: tensor<1x1024x!quant.uniform>) -> (tensor<1x3x!quant.uniform>) { - %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32:1, {2.000000e+0, 2.000000e+0}>> - %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32:1, {2.000000e+0, 2.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32:1, {2.000000e+0, 2.000000e+0, 2.000000e+0}>> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32:1, {2.000000e+0, 2.000000e+0, 2.000000e+0}>>) -> tensor<1x3x!quant.uniform> %2 = stablehlo.uniform_quantize %1 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> return %2 : tensor<1x3x!quant.uniform> } // CHECK-LABEL: dot_general_srq // CHECK-SAME: (%[[ARG_1:.+]]: tensor<1x1024x!quant.uniform) -> tensor<1x3x!quant.uniform> // CHECK-NOT: stablehlo.dot_general -// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00}>>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00}>> -// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<0> : tensor<3xi32>} : () -> tensor<3x!quant.uniform> -// CHECK: %[[FULLY_CONNECTED:.+]] = "tfl.fully_connected"(%[[ARG_1]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00}>>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00,2.000000e+00}>>, value = dense<1> : tensor<3x1024xi8>}> : () -> tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00,2.000000e+00}>> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<3x!quant.uniform>, value = dense<0> : tensor<3xi32>}> : () -> tensor<3x!quant.uniform> +// CHECK: %[[FULLY_CONNECTED:.+]] = "tfl.fully_connected"(%[[ARG_1]], %[[QCONST_0]], %[[QCONST_1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00,2.000000e+00}>>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> // CHECK-NOT: tfl.batch_matmul // CHECK: return %[[FULLY_CONNECTED]] @@ -516,9 +516,9 @@ func.func @dot_general_with_bias_same_shape_srq(%arg0: tensor<1x1024x!quant.unif } // CHECK-LABEL: dot_general_with_bias_same_shape // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x1024x!quant.uniform>) -> tensor<1x3x!quant.uniform> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00,2.000000e+00}>>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00,2.000000e+00}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<2> : tensor<1x3xi32>} : () -> tensor<3x!quant.uniform> -// CHECK: %[[FULLY_CONNECTED:.+]] = 
"tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00,2.000000e+00}>>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00,2.000000e+00}>>, value = dense<1> : tensor<3x1024xi8>}> : () -> tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00,2.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<3x!quant.uniform>, value = dense<2> : tensor<1x3xi32>}> : () -> tensor<3x!quant.uniform> +// CHECK: %[[FULLY_CONNECTED:.+]] = "tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform:f32:0, {2.000000e+00,2.000000e+00,2.000000e+00}>>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> // CHECK: return %[[FULLY_CONNECTED]] // ----- @@ -542,9 +542,9 @@ func.func @dot_general_srq_constant_transpose_rhs(%arg0: tensor<1x3x!quant.unifo // Checks that the `tfl.pseudo_qconst` corresponding to the `stablehlo.constant` // has the same shape. -// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x!quant.uniform>, value = dense<1> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> -// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[FULLY_CONNECTED:.+]] = "tfl.fully_connected"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x3x!quant.uniform>, tensor<2x3x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x!quant.uniform>, value = dense<1> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> +// CHECK: %[[FULLY_CONNECTED:.+]] = "tfl.fully_connected"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x3x!quant.uniform>, tensor<2x3x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x2x!quant.uniform> // Also checks that the i32 -> i8 uniform quantize is absorbed into // `tfl.fully_connected`. @@ -557,9 +557,9 @@ func.func @dot_general_srq_constant_transpose_rhs(%arg0: tensor<1x3x!quant.unifo // (e.g. argument), the conversion to `tfl.fully_connected` doesn't happen. 
// CHECK-LABEL: dot_general_srq_arg_transpose_rhs -func.func @dot_general_srq_arg_transpose_rhs(%arg0: tensor<1x3x!quant.uniform>, %arg1: tensor<2x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { - %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<2x3x!quant.uniform>) -> tensor<3x2x!quant.uniform> - %2 = stablehlo.dot_general %arg0, %1, contracting_dims = [1] x [0] : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +func.func @dot_general_srq_arg_transpose_rhs(%arg0: tensor<1x3x!quant.uniform>, %arg1: tensor<2x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { + %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<2x3x!quant.uniform>) -> tensor<3x2x!quant.uniform> + %2 = stablehlo.dot_general %arg0, %1, contracting_dims = [1] x [0] : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> %3 = stablehlo.uniform_quantize %2 : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> return %3 : tensor<1x2x!quant.uniform> } @@ -577,7 +577,7 @@ func.func @dot_general_srq_arg_transpose_rhs(%arg0: tensor<1x3x!quant.uniform qi8 requantization is // properly lowered to `tfl.batch_matmul`. -func.func @dot_general_srq_to_batch_matmul(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_srq_to_batch_matmul(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -586,14 +586,14 @@ func.func @dot_general_srq_to_batch_matmul(%arg0: tensor<1x2x3x4x!quant.uniform< rhs_contracting_dimensions = [2] >, precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> %1 = stablehlo.uniform_quantize %0 : (tensor<1x2x3x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } // CHECK-LABEL: dot_general_srq_to_batch_matmul -// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x2x3x4x!quant.uniform>, %[[ARG_1:.+]]: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> -// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG_0]], %[[ARG_1]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x2x3x4x!quant.uniform>, %[[ARG_1:.+]]: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG_0]], %[[ARG_1]]) <{adj_x = false, adj_y = false}> : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // CHECK-NOT: stablehlo.dot_general // CHECK-NOT: stablehlo.uniform_quantize // CHECK-NOT: tfl.fully_connected @@ -606,7 +606,7 @@ func.func @dot_general_srq_to_batch_matmul(%arg0: tensor<1x2x3x4x!quant.uniform< // not converted to `tfl.batch_matmul` when there are multiple use of the // intermediate result. 
-func.func @dot_general_srq_multiple_use_of_intermediate_result(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_srq_multiple_use_of_intermediate_result(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -615,7 +615,7 @@ func.func @dot_general_srq_multiple_use_of_intermediate_result(%arg0: tensor<1x2 rhs_contracting_dimensions = [2] >, precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> %1 = stablehlo.uniform_quantize %0 : (tensor<1x2x3x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> %2 = stablehlo.uniform_quantize %0 : (tensor<1x2x3x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> %3 = stablehlo.add %1, %2 : tensor<1x2x3x5x!quant.uniform> @@ -646,11 +646,11 @@ func.func @conv_srq(%arg0: tensor<1x5x5x2x!quant.uniform } // CHECK-LABEL: func.func @conv_srq // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x5x5x2x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> -// CHECK-DAG: %[[CONST_0:.+]] = "tfl.pseudo_const"() {value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x4x4x2xi8>} : () -> tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x!quant.uniform>, value = dense<0> : tensor<4xi32>} : () -> tensor<4x!quant.uniform> +// CHECK-DAG: %[[CONST_0:.+]] = "tfl.pseudo_const"() <{value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x4x4x2xi8>}> : () -> tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<4x!quant.uniform>, value = dense<0> : tensor<4xi32>}> : () -> tensor<4x!quant.uniform> // CHECK: %[[PAD:.+]] = "tfl.pad"(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x5x5x2x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x7x7x2x!quant.uniform> -// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x2x!quant.uniform>, tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x7x7x2x!quant.uniform>, tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<4x!quant.uniform>) 
-> tensor<1x4x4x4x!quant.uniform> // CHECK: return %[[CONV_2D]] func.func @conv_same_padding_srq(%arg0: tensor<1x32x32x3x!quant.uniform>) -> (tensor<1x32x32x2x!quant.uniform>) { @@ -661,9 +661,9 @@ func.func @conv_same_padding_srq(%arg0: tensor<1x32x32x3x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>} : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>}> : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> // CHECK: return %[[CONV_2D]] : tensor<1x32x32x2x!quant.uniform> // ----- @@ -676,9 +676,9 @@ func.func @conv_same_padding_srq_non_unit_strides(%arg0: tensor<1x32x32x3x!quant } // CHECK-LABEL: func.func @conv_same_padding_srq_non_unit_strides // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x32x32x3x!quant.uniform>) -> tensor<1x16x16x2x!quant.uniform> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>} : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x16x16x2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>}> : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], 
%[[QCONST_0]], %[[QCONST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x16x16x2x!quant.uniform> // CHECK: return %[[CONV_2D]] : tensor<1x16x16x2x!quant.uniform> // ----- @@ -692,11 +692,11 @@ func.func @conv_srq_transpose_conv(%arg0: tensor<1x5x5x2x!quant.uniform>) -> tensor<1x14x14x4x!quant.uniform> // CHECK-DAG: %[[CONST_0:.+]] = arith.constant dense<[1, 14, 14, 4]> : tensor<4xi32> -// CHECK-DAG: %[[CONST_1:.*]] = "tfl.pseudo_const"() {value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x2x2x2xi8>} : () -> tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x!quant.uniform>, value = dense<0> : tensor<4xi32>} : () -> tensor<4x!quant.uniform> +// CHECK-DAG: %[[CONST_1:.*]] = "tfl.pseudo_const"() <{value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x2x2x2xi8>}> : () -> tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<4x!quant.uniform>, value = dense<0> : tensor<4xi32>}> : () -> tensor<4x!quant.uniform> // CHECK: %[[PAD:.+]] = "tfl.pad"(%[[ARG_0]], %[[CONST_1]]) : (tensor<1x5x5x2x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x7x7x2x!quant.uniform> -// CHECK: %[[TRANSPOSE_CONV_2D:.+]] = "tfl.transpose_conv"(%[[CONST_0]], %[[QCONST_0]], %[[PAD]], %[[QCONST_1]]) {fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 4 : i32} : (tensor<4xi32>, tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<1x7x7x2x!quant.uniform>, tensor<4x!quant.uniform>) -> tensor<1x14x14x4x!quant.uniform> +// CHECK: %[[TRANSPOSE_CONV_2D:.+]] = "tfl.transpose_conv"(%[[CONST_0]], %[[QCONST_0]], %[[PAD]], %[[QCONST_1]]) <{fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 4 : i32}> : (tensor<4xi32>, tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<1x7x7x2x!quant.uniform>, tensor<4x!quant.uniform>) -> tensor<1x14x14x4x!quant.uniform> // CHECK: return %[[TRANSPOSE_CONV_2D]] // ----- @@ -721,11 +721,11 @@ func.func @conv_with_bias_and_relu_srq(%arg0: tensor<1x5x5x2x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> -// CHECK-DAG: %[[CONST_0:.+]] = "tfl.pseudo_const"() {value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x4x4x2xi8>} : () -> tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> -// CHECK-DAG: %[[QCONST_1:.+]] = 
"tfl.pseudo_qconst"() {qtype = tensor<4x!quant.uniform>, value = dense<5> : tensor<1x1x1x4xi32>} : () -> tensor<4x!quant.uniform> +// CHECK-DAG: %[[CONST_0:.+]] = "tfl.pseudo_const"() <{value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x4x4x2xi8>}> : () -> tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<4x!quant.uniform>, value = dense<5> : tensor<1x1x1x4xi32>}> : () -> tensor<4x!quant.uniform> // CHECK: %[[PAD:.+]] = "tfl.pad"(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x5x5x2x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x7x7x2x!quant.uniform> -// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x2x!quant.uniform>, tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x7x7x2x!quant.uniform>, tensor<4x4x4x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> // CHECK: return %[[CONV_2D]] func.func @conv_with_bias_same_padding_srq(%arg0: tensor<1x32x32x3x!quant.uniform>) -> (tensor<1x32x32x2x!quant.uniform>) { @@ -739,9 +739,9 @@ func.func @conv_with_bias_same_padding_srq(%arg0: tensor<1x32x32x3x!quant.unifor } // CHECK-LABEL: func.func @conv_with_bias_same_padding_srq // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x32x32x3x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>} : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<5> : tensor<1x1x1x2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>}> : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x!quant.uniform>, value = dense<5> : tensor<1x1x1x2xi32>}> : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, 
fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> // CHECK: return %[[CONV_2D]] func.func @conv_with_bias_same_padding_srq_depthwise(%arg0: tensor<1x4x5x3x!quant.uniform>) -> (tensor<1x5x6x3x!quant.uniform>) { @@ -755,11 +755,11 @@ func.func @conv_with_bias_same_padding_srq_depthwise(%arg0: tensor<1x4x5x3x!quan } // CHECK-LABEL: func.func @conv_with_bias_same_padding_srq_depthwise // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x4x5x3x!quant.uniform>) -> tensor<1x5x6x3x!quant.uniform> -// CHECK-DAG: %[[CONST_0:.+]] = "tfl.pseudo_const"() {value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32> -// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x2x3x!quant.uniform:f32:3, {3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<1x2x2x3xi8>} : () -> tensor<1x2x2x3x!quant.uniform:f32:3, {3.000000e+00,3.000000e+00,3.000000e+00}>> -// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<5> : tensor<1x1x1x3xi32>} : () -> tensor<3x!quant.uniform> +// CHECK-DAG: %[[CONST_0:.+]] = "tfl.pseudo_const"() <{value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x2x3x!quant.uniform:f32:3, {3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<1x2x2x3xi8>}> : () -> tensor<1x2x2x3x!quant.uniform:f32:3, {3.000000e+00,3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<3x!quant.uniform>, value = dense<5> : tensor<1x1x1x3xi32>}> : () -> tensor<3x!quant.uniform> // CHECK: %[[PAD:.+]] = "tfl.pad"(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x4x5x3x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x6x7x3x!quant.uniform> -// CHECK: %[[DEPTHWISE_CONV_2D:.+]] = "tfl.depthwise_conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x6x7x3x!quant.uniform>, tensor<1x2x2x3x!quant.uniform:f32:3, {3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<3x!quant.uniform>) -> tensor<1x5x6x3x!quant.uniform> +// CHECK: %[[DEPTHWISE_CONV_2D:.+]] = "tfl.depthwise_conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) <{depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x6x7x3x!quant.uniform>, tensor<1x2x2x3x!quant.uniform:f32:3, {3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<3x!quant.uniform>) -> tensor<1x5x6x3x!quant.uniform> // CHECK: return %[[DEPTHWISE_CONV_2D]] // ----- @@ -872,7 +872,7 @@ func.func @concatenate( // CHECK-LABEL: concatenate // CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x!quant.uniform>, %[[ARG1:.+]]: tensor<1x2x!quant.uniform> // CHECK-NOT: stablehlo.concatenate -// CHECK: %[[CONCAT:.+]] = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<3x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<4x2x!quant.uniform> +// CHECK: %[[CONCAT:.+]] = "tfl.concatenation"(%arg0, %arg1) <{axis = 0 : i32, fused_activation_function = "NONE"}> : 
(tensor<3x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<4x2x!quant.uniform> // CHECK: return %[[CONCAT]] // ----- @@ -998,7 +998,7 @@ func.func @strided_slice( // CHECK{LITERAL}: dense<[3, 4]> : tensor<2xi32> // CHECK: %[[STRIDE:.+]] = arith.constant // CHECK{LITERAL}: dense<[2, 3]> : tensor<2xi32> -// CHECK: %[[SLICE:.+]] = "tfl.strided_slice"(%[[ARG0]], %[[START]], %[[SIZE]], %[[STRIDE]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<3x6x!quant.uniform>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2x!quant.uniform> +// CHECK: %[[SLICE:.+]] = "tfl.strided_slice"(%[[ARG0]], %[[START]], %[[SIZE]], %[[STRIDE]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<3x6x!quant.uniform>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2x!quant.uniform> // CHECK: return %[[SLICE]] // ----- @@ -1456,7 +1456,7 @@ func.func @dynamic_slice( // CHECK: %[[MIN1:.+]] = "tfl.minimum"(%[[BITCAST1]], %[[MAX1]]) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> // CHECK: %[[BITCAST2:.+]] = "tfl.bitcast"(%[[ARG2]]) : (tensor) -> tensor<1xi64> // CHECK: %[[MIN2:.+]] = "tfl.minimum"(%[[BITCAST2]], %[[MAX2]]) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> -// CHECK: %[[CONCAT:.+]] = "tfl.concatenation"(%[[MIN1]], %[[MIN2]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> +// CHECK: %[[CONCAT:.+]] = "tfl.concatenation"(%[[MIN1]], %[[MIN2]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> // CHECK: %[[MAX:.+]] = "tfl.maximum"(%[[CONCAT]], %[[ZERO]]) : (tensor<2xi64>, tensor<1xi64>) -> tensor<2xi64> // CHECK: %[[SLICE:.+]] = "tfl.slice"(%[[ARG0]], %[[MAX]], %[[SLICE_SIZE]]) // CHECK-SAME: (tensor<4x4x!quant.uniform>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x1x!quant.uniform> @@ -1490,7 +1490,7 @@ func.func @add(%arg0: tensor<1x3x!quant.uniform>, %arg1: } // CHECK-LABEL: func @add -// CHECK: %[[ADD:.+]] = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[ADD:.+]] = tfl.add(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> // CHECK: return %[[ADD]] // ----- @@ -1517,7 +1517,7 @@ func.func @quantized_constant() -> tensor<1x2x4x5x!quant.uniform> } -// CHECK: %[[QCONST:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} +// CHECK: %[[QCONST:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>}> // CHECK-SAME: () -> tensor<1x2x4x5x!quant.uniform> // CHECK: return %[[QCONST]] @@ -1556,27 +1556,67 @@ func.func @dot_general_hybrid(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x5xf32> return %1 : tensor<1x2x3x5xf32> } -// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} +// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>}> // CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x4x5xf32> // CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG0]], %[[DQ]], batching_dims = [0, 1] x [0, 
1], contracting_dims = [3] x [2], precision = [DEFAULT, DEFAULT] : (tensor<1x2x3x4xf32>, tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32>
// CHECK: return %[[DOT]]

// -----

-// Tests that a hybrid quantized convolution is splitted into dequantize and
-// float convolution.
+// Tests that a hybrid per-channel quantized convolution for tfl.conv_2d is
+// split into dequantize and float stablehlo.convolution.

-// CHECK-LABEL: func @convolution_hybrid
+// CHECK-LABEL: func @convolution_hybrid_per_channel
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x3x4xf32>
-func.func @convolution_hybrid(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x2xf32> {
+func.func @convolution_hybrid_per_channel(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x2xf32> {
   %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform>
   %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2xf32>
   return %1 : tensor<1x3x3x2xf32>
 }

-// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x3x4x2x!quant.uniform>, value = dense<3> : tensor<3x3x4x2xi8>}
-// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<3x3x4x2x!quant.uniform>) -> tensor<3x3x4x2xf32>
+// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>}>
+// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>) -> tensor<2x3x3x4xf32>
+// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG0]], %[[DQ]])
+// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
+// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<2x3x3x4xf32>) -> tensor<1x3x3x2xf32>
+// CHECK: return %[[CONV]]
+
+// -----
+
+// Tests that a hybrid per-tensor quantized convolution for tfl.conv_2d is
+// split into dequantize and float stablehlo.convolution.
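+// ("Hybrid" in these tests means weight-only quantization: the i8 weights
+// keep their quantization parameters on the tfl.pseudo_qconst, tfl.dequantize
+// restores f32 values at runtime, and the convolution itself runs in float.)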
+
+// CHECK-LABEL: func @convolution_hybrid_per_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x3x4xf32>
+func.func @convolution_hybrid_per_tensor(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x2xf32> {
+  %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform>
+  %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2xf32>
+  return %1 : tensor<1x3x3x2xf32>
+}
+
+// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x3x3x4x!quant.uniform>, value = dense<3> : tensor<2x3x3x4xi8>}>
+// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<2x3x3x4x!quant.uniform>) -> tensor<2x3x3x4xf32>
+// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG0]], %[[DQ]])
+// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
+// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<2x3x3x4xf32>) -> tensor<1x3x3x2xf32>
+// CHECK: return %[[CONV]]
+
+// -----
+
+// Tests that a hybrid per-channel quantized convolution for tfl.depthwise_conv
+// is split into dequantize and float stablehlo.convolution.
+
+// CHECK-LABEL: func @depthwise_convolution_hybrid_per_channel
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x3x4xf32>
+func.func @depthwise_convolution_hybrid_per_channel(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
+  %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x1x4xi8>} : () -> tensor<3x3x1x4x!quant.uniform>
+  %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x1x4x!quant.uniform>) -> tensor<1x3x3x4xf32>
+  return %1 : tensor<1x3x3x4xf32>
+}
+
+// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x3x3x4x!quant.uniform:f32:3, {2.000000e+02,3.000000e+03,2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<1x3x3x4xi8>}>
+// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<1x3x3x4x!quant.uniform:f32:3, {2.000000e+02,3.000000e+03,2.000000e+02,3.000000e+03}>>) -> tensor<1x3x3x4xf32>
 // CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG0]], %[[DQ]])
-// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
-// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<3x3x4x2xf32>) -> tensor<1x3x3x2xf32>
+// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64}
+// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32>
 // CHECK: return %[[CONV]]
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc
index 15481b9a0a1ad2..52f2c4be02a3aa 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc
@@ -73,7 +73,7 @@ bool IsI8ToF32Cast(stablehlo::ConvertOp convert_op) {
   const bool is_i8_operand =
convert_op.getOperand().getType().getElementType().isInteger(/*width=*/8); const bool is_f32_result = - convert_op.getResult().getType().getElementType().isa(); + mlir::isa(convert_op.getResult().getType().getElementType()); return is_i8_operand && is_f32_result; } @@ -92,7 +92,7 @@ bool IsI32ToF32Cast(stablehlo::ConvertOp convert_op) { convert_op.getOperand().getType().getElementType().isInteger( /*width=*/32); const bool is_f32_result = - convert_op.getResult().getType().getElementType().isa(); + mlir::isa(convert_op.getResult().getType().getElementType()); return is_i32_operand && is_f32_result; } @@ -104,7 +104,8 @@ LogicalResult MatchZeroPointsOperand(Value zero_points) { return failure(); } - auto zero_points_type = zero_points.getType().dyn_cast_or_null(); + auto zero_points_type = + mlir::dyn_cast_or_null(zero_points.getType()); if (!zero_points_type) { LLVM_DEBUG(llvm::dbgs() << "Zero point value should be a tensor type. Got: " << zero_points_type << ".\n"); @@ -112,7 +113,7 @@ LogicalResult MatchZeroPointsOperand(Value zero_points) { } if (Type zero_points_element_type = zero_points_type.getElementType(); - !zero_points_element_type.isa()) { + !mlir::isa(zero_points_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Zero point should be an integer type. Got: " << zero_points_element_type << ".\n"); return failure(); @@ -146,7 +147,7 @@ LogicalResult MatchInverseScalesOperand(Value inverse_scales) { } auto inverse_scales_type = - inverse_scales.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(inverse_scales.getType()); if (!inverse_scales_type) { LLVM_DEBUG(llvm::dbgs() << "Inverse scales should be a tensor type. Got: " << inverse_scales_type << ".\n"); @@ -154,7 +155,7 @@ LogicalResult MatchInverseScalesOperand(Value inverse_scales) { } if (Type inverse_scales_element_type = inverse_scales_type.getElementType(); - !inverse_scales_element_type.isa()) { + !mlir::isa(inverse_scales_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Inverse scales element should be a float type. Got: " << inverse_scales_element_type << ".\n"); @@ -207,7 +208,7 @@ class UniformQuantizeFunctionCallPattern { } auto input_value_type = - input_value.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input_value.getType()); if (!input_value_type) { LLVM_DEBUG(llvm::dbgs() << "Failed to match @uniform_quantize function call pattern. " @@ -216,7 +217,7 @@ class UniformQuantizeFunctionCallPattern { } if (Type input_element_type = input_value_type.getElementType(); - !input_element_type.isa()) { + !mlir::isa(input_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match @uniform_quantize function call pattern. " "Input value's element type must be a float. Got: " @@ -299,7 +300,7 @@ class UniformDequantizeFunctionCallPattern { } auto input_value_type = - input_value.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input_value.getType()); if (!input_value_type) { LLVM_DEBUG(llvm::dbgs() << "Failed to match @uniform_dequantize call pattern. Input " @@ -309,7 +310,7 @@ class UniformDequantizeFunctionCallPattern { } if (Type input_element_type = input_value_type.getElementType(); - !input_element_type.isa()) { + !mlir::isa(input_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match @uniform_dequantize call pattern. Input " "value's element type must be integer. Got: " @@ -433,8 +434,9 @@ class ComposeUniformQuantizedConvolutionOp LogicalResult match(stablehlo::ConvolutionOp op) const final { // Verify operands' types. 
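// The operands are still float tensors at this point: the quantized values
// are disguised as f32 via the upcast trick, so a non-float operand means the
// subgraph cannot be the uniform-quantized convolution being composed.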
for (Type operand_type : op.getOperandTypes()) { - if (Type element_type = operand_type.cast().getElementType(); - !element_type.isa()) { + if (Type element_type = + mlir::cast(operand_type).getElementType(); + !mlir::isa(element_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match. The operand type must be a float. Got: " << element_type << ".\n"); @@ -477,8 +479,9 @@ class ComposeUniformQuantizedConvolutionOp // Match the subgraph that receives the convolution output. Value conv_output_value = op.getResult(); if (auto output_element_type = - conv_output_value.getType().cast().getElementType(); - !output_element_type.isa()) { + mlir::cast(conv_output_value.getType()) + .getElementType(); + !mlir::isa(output_element_type)) { LLVM_DEBUG( llvm::dbgs() << "Failed to match. Output type is expected to be a float. Got: " @@ -530,14 +533,12 @@ class ComposeUniformQuantizedConvolutionOp return failure(); } - if (!(other_zp_i8_to_f32_convert_op.getResult() - .getType() - .getElementType() - .isa() && - other_zp_i8_to_f32_convert_op.getOperand() - .getType() - .getElementType() - .isa())) { + if (!(mlir::isa(other_zp_i8_to_f32_convert_op.getResult() + .getType() + .getElementType()) && + mlir::isa(other_zp_i8_to_f32_convert_op.getOperand() + .getType() + .getElementType()))) { LLVM_DEBUG( llvm::dbgs() << "Failed to match. The ConvertOp is not an i8->f32 type cast.\n"); @@ -671,8 +672,8 @@ class ComposeUniformQuantizedConvolutionOp rewriter.create( uniform_quantize_call_op.getLoc(), /*result=*/ - input_value.getType().cast().clone( - input_quantized_element_type), + mlir::cast(input_value.getType()) + .clone(input_quantized_element_type), /*operand=*/input_value); rewriter.replaceAllUsesWith(input_i8_to_f32_convert_op.getResult(), @@ -689,20 +690,21 @@ class ComposeUniformQuantizedConvolutionOp // This is i8 values disguised as f32 (due to the upcast trick). Simply // cast them to i8. ElementsAttr filter_value = filter_constant_op.getValue(); - filter_i8_value_attr = filter_value.cast().mapValues( - rewriter.getI8Type(), [](const APFloat& val) -> APInt { - APSInt convertedInt(/*BitWidth=*/8, /*isUnsigned=*/false); - bool ignored; - val.convertToInteger(convertedInt, APFloat::rmTowardZero, &ignored); - return convertedInt; - }); + filter_i8_value_attr = + mlir::cast(filter_value) + .mapValues(rewriter.getI8Type(), [](const APFloat& val) -> APInt { + APSInt convertedInt(/*BitWidth=*/8, /*isUnsigned=*/false); + bool ignored; + val.convertToInteger(convertedInt, APFloat::rmTowardZero, + &ignored); + return convertedInt; + }); } else if (isa(filter_op) && isa( filter_op->getOperand(0).getDefiningOp())) { - filter_i8_value_attr = + filter_i8_value_attr = mlir::cast( cast(filter_op->getOperand(0).getDefiningOp()) - .getValue() - .cast(); + .getValue()); } // Create Uniform Quantized constant for the filter. @@ -719,9 +721,9 @@ class ComposeUniformQuantizedConvolutionOp scale_combined_broadcast_in_dim_op.getOperand().getDefiningOp()); SmallVector filter_scale_values; - for (const auto combined_scale_value : combined_scale_constant_op.getValue() - .cast() - .getValues()) { + for (const auto combined_scale_value : + mlir::cast(combined_scale_constant_op.getValue()) + .getValues()) { // UniformQuantizedPerAxisType requires scales to have double dtype. 
const double filter_scale_value = static_cast( combined_scale_value * input_inverse_scales_value); @@ -780,7 +782,8 @@ class ComposeUniformQuantizedConvolutionOp Value conv_output_value = op.getResult(); auto output_uniform_quantized_tensor_type = RankedTensorType::getChecked( rewriter.getUnknownLoc(), - /*shape=*/conv_output_value.getType().cast().getShape(), + /*shape=*/ + mlir::cast(conv_output_value.getType()).getShape(), output_uniform_quantized_type); SmallVector new_conv_output_types = { @@ -1017,8 +1020,8 @@ class ComposeUniformQuantizedDotGeneralOp rewriter.create( input_i8_to_f32_convert_op.getLoc(), /*result=*/ - input_value.getType().cast().clone( - input_uniform_quantized_type), + mlir::cast(input_value.getType()) + .clone(input_uniform_quantized_type), /*operand=*/input_value); rewriter.replaceAllUsesWith(input_i8_to_f32_convert_op.getResult(), @@ -1029,13 +1032,13 @@ class ComposeUniformQuantizedDotGeneralOp stablehlo::ConstantOp filter_constant_op = GetFilterConstantOp(filter_value); auto filter_value_attr = - filter_constant_op.getValue().cast(); + mlir::cast(filter_constant_op.getValue()); if (filter_value_attr.getElementType().isF32()) { // This is i8 values disguised as f32 (due to the upcast trick). Simply // cast them to i8. filter_value_attr = - filter_value_attr.cast().mapValues( - rewriter.getI8Type(), [](const APFloat& val) -> APInt { + mlir::cast(filter_value_attr) + .mapValues(rewriter.getI8Type(), [](const APFloat& val) -> APInt { APSInt converted_int(/*BitWidth=*/8, /*isUnsigned=*/false); bool ignored; val.convertToInteger(converted_int, APFloat::rmTowardZero, @@ -1072,9 +1075,9 @@ class ComposeUniformQuantizedDotGeneralOp auto merged_scale_constant_op = cast(multiply_op_second_operand.getDefiningOp()); SmallVector filter_scale_values; - for (const auto merged_scale : merged_scale_constant_op.getValue() - .cast() - .getValues()) { + for (const auto merged_scale : + mlir::cast(merged_scale_constant_op.getValue()) + .getValues()) { // (s1 * s2) * (1 / s1) = s2 // UniformQuantizedPerAxisType requires scales to have double dtype. 
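// Multiplying each merged scale (s1 * s2) by the input's inverse scale (1 / s1)
// recovers the filter-only scale s2, which is what gets recorded per axis.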
filter_scale_values.push_back( @@ -1086,7 +1089,7 @@ class ComposeUniformQuantizedDotGeneralOp const int quantization_dimension = GetFilterQuantizationDimension( op.getDotDimensionNumbers(), - filter_value_attr.getType().cast().getRank()); + mlir::cast(filter_value_attr.getType()).getRank()); const UniformQuantizedPerAxisType filter_uniform_quantized_type = CreateI8F32UniformQuantizedPerAxisType( filter_constant_op.getLoc(), *rewriter.getContext(), @@ -1097,8 +1100,8 @@ class ComposeUniformQuantizedDotGeneralOp auto quantized_filter_constant_op = rewriter.create( filter_constant_op.getLoc(), /*output=*/ - filter_constant_op.getResult().getType().cast().clone( - filter_uniform_quantized_type), + mlir::cast(filter_constant_op.getResult().getType()) + .clone(filter_uniform_quantized_type), /*value=*/filter_value_attr); rewriter.replaceAllUsesWith(filter_value, @@ -1137,8 +1140,8 @@ class ComposeUniformQuantizedDotGeneralOp auto new_dot_general_op = rewriter.create( op.getLoc(), /*resultType0=*/ - op.getResult().getType().cast().clone( - output_uniform_quantized_type), + mlir::cast(op.getResult().getType()) + .clone(output_uniform_quantized_type), /*lhs=*/op.getLhs(), /*rhs=*/op.getRhs(), /*dot_dimension_numbers=*/op.getDotDimensionNumbers(), /*precision_config=*/op.getPrecisionConfigAttr()); @@ -1395,8 +1398,8 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations rewriter.create( input1_uniform_quantize_call_op.getLoc(), /*result=*/ - input1_value.getType().cast().clone( - input1_uniform_quantized_type), + mlir::cast(input1_value.getType()) + .clone(input1_uniform_quantized_type), /*operand=*/input1_value); rewriter.replaceAllUsesWith(input1_zero_point_subtract_op.getResult(), @@ -1434,8 +1437,8 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations rewriter.create( input2_uniform_quantize_call_op.getLoc(), /*result=*/ - input2_value.getType().cast().clone( - input2_uniform_quantized_type), + mlir::cast(input2_value.getType()) + .clone(input2_uniform_quantized_type), /*operand=*/input2_value); rewriter.replaceAllUsesWith(input2_zero_point_subtract_op.getResult(), @@ -1482,8 +1485,8 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations auto new_dot_general_op = rewriter.create( op.getLoc(), /*resultType0=*/ - op.getResult().getType().cast().clone( - output_uniform_quantized_type), + mlir::cast(op.getResult().getType()) + .clone(output_uniform_quantized_type), /*lhs=*/op.getLhs(), /*rhs=*/op.getRhs(), /*dot_dimension_numbers=*/op.getDotDimensionNumbers(), /*precision_config=*/op.getPrecisionConfigAttr()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc index 801c8775682cbd..8c28f2e5e5df4b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc @@ -55,7 +55,7 @@ DenseIntElementsAttr GetPaddingArrayAttr(Builder& builder, Operation* old_op) { } ShapedType GetPaddedType(Operation* old_op) { - auto input_type = old_op->getOperand(0).getType().cast(); + auto input_type = mlir::cast(old_op->getOperand(0).getType()); auto input_shape = input_type.getShape(); // NCHW int64_t batch_size = input_shape[0]; int64_t channel_size = input_shape[1]; @@ -124,7 +124,7 @@ StringAttr GetPaddingStringAttr(Builder& builder, Operation* old_op) { auto composite_attrs = composite_op.getCompositeAttributes(); auto operand_shape = - 
composite_op.getOperand(0).getType().cast().getShape();
+      mlir::cast(composite_op.getOperand(0).getType()).getShape();
   // NC(H)(W)
   std::vector spatial_dim_sizes = {
       static_cast(operand_shape[2]),
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td
index 829cf2fbaf16a4..5b9324c2a1782b 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td
@@ -49,3 +49,24 @@ def LegalizeTorchUpsampleBlinear2dComposite: Pat<
     (Arith_ConstantOp ConstantAttr,"{0, 3, 1, 2}">)),
   [(IsSupportedNchwUpsampleBlinear $input, $old_val, $attrs)]>;
+
+// TODO(b/333961789): Add support for NCHW layout for PyTorch resize, plus JAX
+// supports NCHW inputs as well, so we need to add a reliable way of checking
+// the layout.
+// Pattern to lower a stablehlo.composite with `jax.image.resize` in `nearest`
+// mode to a tflite.resize_nearest_neighbor op.
+def LegalizeJaxResizeNearestNeighbor2dComposite: Pat<
+  (MHLO_CompositeOp:$old_val
+    (variadic $input),
+    ConstantStrAttr, $attrs, $_, $_),
+  (TFL_ResizeNearestNeighborOp
+    $input,
+    (Arith_ConstantOp:$output_size (GetI32DenseAttr (GetAsVectorAttr<"output_size"> $attrs))),
+    ConstBoolAttrFalse,
+    ConstBoolAttrTrue)>;
+
+
+def LegalizeCompositeGELU : Pat<(MHLO_CompositeOp:$composite
+                    (variadic $inputs),
+                    ConstantStrAttr, $_, $_, $_),
+                  (TFL_GeluOp $inputs, ConstBoolAttrFalse)>;
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc
index 403bf9968a9acd..2809c81458918c 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 namespace mlir {
 namespace odml {
@@ -65,8 +66,8 @@ bool GetI32VectorFromDenseI64CompositeAttr(
 bool IsSupportedNchwUpsampleBlinear(
     Value input, Value output, const DenseIntElementsAttr& output_size_attr) {
-  auto input_shape = input.getType().cast().getShape();
-  auto output_shape = output.getType().cast().getShape();
+  auto input_shape = mlir::cast(input.getType()).getShape();
+  auto output_shape = mlir::cast(output.getType()).getShape();
   // Only support 4D tensor.
if (input_shape.size() != 4 || output_shape.size() != 4) { @@ -89,7 +90,7 @@ bool IsSupportedNchwUpsampleBlinear( ShapedType GetNhwcReturnTypeFromNchw(Operation* old_op) { auto composite_result_shape = - old_op->getResults().front().getType().cast().getShape(); + mlir::cast(old_op->getResults().front().getType()).getShape(); std::array output_shape; // NHWC <- NCHW output_shape[0] = composite_result_shape[0]; @@ -97,7 +98,7 @@ ShapedType GetNhwcReturnTypeFromNchw(Operation* old_op) { output_shape[2] = composite_result_shape[3]; output_shape[3] = composite_result_shape[1]; - auto input_type = old_op->getOperand(0).getType().cast(); + auto input_type = mlir::cast(old_op->getOperand(0).getType()); return RankedTensorType::get(output_shape, input_type.getElementType()); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h index 0691dc74997212..79d0910bce18a4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h @@ -38,10 +38,10 @@ template bool EnsureAttribute(const DictionaryAttr& composite_attributes, const std::string& attr_name, AttrType* out_attr) { Attribute attr = composite_attributes.get(attr_name); - if (!attr.isa_and_nonnull()) { + if (!mlir::isa_and_nonnull(attr)) { return false; } - if (AttrType content = attr.dyn_cast()) { + if (AttrType content = mlir::dyn_cast(attr)) { *out_attr = content; return true; } else { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td index d39a8efb8b13b3..30d6f4247fba52 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td @@ -22,7 +22,6 @@ include "mlir/IR/PatternBase.td" def GetNhwcReturnTypeFromNchw: NativeCodeCall< "GetNhwcReturnTypeFromNchw((*$0.begin()).getDefiningOp())">; - // When given a DenseIntElementsAttr containing I64 elements, this extracts // one I32IntegerAttr from the given index. class GetI32At: NativeCodeCall< diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc index 847738e5cc7cbe..c2b31aeb540720 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" @@ -117,7 +118,7 @@ static Attribute BinaryFolder(Op *op) { auto rhs = dyn_cast_or_null(rhs_op.getValue()); if (!lhs || !rhs) return {}; - ShapedType type = op->getType().template cast(); + ShapedType type = mlir::cast(op->getType()); if (!type.hasStaticShape()) { return {}; } @@ -125,15 +126,15 @@ static Attribute BinaryFolder(Op *op) { Type etype = type.getElementType(); // Evaluate for element types. 
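// Each BinaryFolder instantiation targets one element category; when the
// element type does not match it, return an empty attribute so nothing folds.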
- if (!etype.isa()) { + if (!mlir::isa(etype)) { return {}; } // Special case for folding splats no matter how large. // Only covers the case of both attrs being splats; operation-specific cases // like adding a zero or multiplying by one are handled elsewhere. - SplatElementsAttr splatLhs = lhs.template dyn_cast(); - SplatElementsAttr splatRhs = rhs.template dyn_cast(); + SplatElementsAttr splatLhs = mlir::dyn_cast(lhs); + SplatElementsAttr splatRhs = mlir::dyn_cast(rhs); if (splatLhs && splatRhs) { auto signedLhs = addSign(splatLhs.getSplatValue(), etype); auto signedRhs = addSign(splatRhs.getSplatValue(), etype); @@ -195,10 +196,10 @@ class FoldBroadcastInDimBeforeBinaryElementwiseOp auto bcast_dims = bcast_op.getBroadcastDimensions(); auto elem_type = const_val.getElementType(); Attribute result; - if (elem_type.template isa()) { + if (mlir::isa(elem_type)) { result = ConstFoldBroadcastInDim(result_type, const_val, bcast_dims); - } else if (elem_type.template isa()) { + } else if (mlir::isa(elem_type)) { result = ConstFoldBroadcastInDim(result_type, const_val, bcast_dims); } else { @@ -217,14 +218,14 @@ using FoldBroadcastInDimBeforeMulOp = // Constant folds mhlo.mul, this folder doesn't have an upper limit on how many // elements can be folded. LogicalResult ConstantFoldMul(mhlo::MulOp op, PatternRewriter &rewriter) { - ShapedType type = op.getType().dyn_cast(); + ShapedType type = mlir::dyn_cast(op.getType()); Type etype = type.getElementType(); Attribute result = {}; - if (etype.isa()) { + if (mlir::isa(etype)) { result = BinaryFolder>( &op); - } else if (etype.isa()) { + } else if (mlir::isa(etype)) { result = BinaryFolder>( &op); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.cc index 2d9308b05cb47b..9e7a5d424a2ecc 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.cc @@ -237,9 +237,9 @@ bool MatchReshapedIota(DenseIntElementsAttr dimensions, Value iota) { auto reshape_op = dyn_cast_or_null(iota.getDefiningOp()); if (!reshape_op) return false; auto operand_type = - reshape_op.getOperand().getType().dyn_cast(); + mlir::dyn_cast(reshape_op.getOperand().getType()); if (!operand_type || !operand_type.hasStaticShape()) return false; - auto reshape_type = reshape_op.getType().cast(); + auto reshape_type = mlir::cast(reshape_op.getType()); // Reshape can take a 1-D iota input and add extra dims of size one. if (operand_type.getRank() != 1) return false; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index 3a483f44568ce2..96081a2b2b1bd8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -162,10 +162,8 @@ class ConvertNdConvOp : public OpConversionPattern { } // tf Convolution doesn't support quantized type. 
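// A quantized filter element type therefore fails the match immediately and
// leaves the convolution untouched by this pattern.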
- if (conv_op.getRhs() - .getType() - .getElementType() - .isa()) { + if (mlir::isa( + conv_op.getRhs().getType().getElementType())) { return failure(); } @@ -193,11 +191,11 @@ class ConvertNdConvOp : public OpConversionPattern { const int kernel_input_feature_dimension = dnums.getKernelInputFeatureDimension(); const int input_channels = - conv_op.getLhs().getType().cast().getDimSize( - input_feature_dimension); + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); const int kernel_input_channels = - conv_op.getRhs().getType().cast().getDimSize( - kernel_input_feature_dimension); + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_input_feature_dimension); int feature_group_count = conv_op.getFeatureGroupCount(); // check if group count is valid @@ -238,14 +236,14 @@ class ConvertNdConvOp : public OpConversionPattern { }; static bool IsSupportedConvOp(mhlo::ConvolutionOp conv_op) { - if (!conv_op.getRhs().getType().cast().hasStaticShape()) { + if (!mlir::cast(conv_op.getRhs().getType()).hasStaticShape()) { return false; } - if (!conv_op.getLhs().getType().cast().hasStaticShape() && - !conv_op.getType().cast().hasStaticShape()) { + if (!mlir::cast(conv_op.getLhs().getType()).hasStaticShape() && + !mlir::cast(conv_op.getType()).hasStaticShape()) { auto dnums = conv_op.getDimensionNumbers(); - auto lhs_type = conv_op.getLhs().getType().cast(); - auto out_type = conv_op.getType().cast(); + auto lhs_type = mlir::cast(conv_op.getLhs().getType()); + auto out_type = mlir::cast(conv_op.getType()); int64_t input_batch_dim = dnums.getInputBatchDimension(); int64_t out_batch_dim = dnums.getOutputBatchDimension(); for (size_t i = 0; i < lhs_type.getRank(); ++i) { @@ -263,10 +261,7 @@ class ConvertNdConvOp : public OpConversionPattern { if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue() != 1) return false; - if (conv_op.getWindowStrides() - .value() - .getType() - .cast() + if (mlir::cast(conv_op.getWindowStrides().value().getType()) .getRank() != 1) return false; @@ -290,10 +285,10 @@ class ConvertNdConvOp : public OpConversionPattern { int64_t pad_low_int64; int64_t pad_high_int64; tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( - conv_op.getLhs().getType().cast().getDimSize( - input_spatial_dim[i]), - conv_op.getRhs().getType().cast().getDimSize( - kernel_spatial_dim[i]), + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dim[i]), + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_spatial_dim[i]), dilation[dim], strides[dim], tensorflow::Padding::SAME, &output_size, &pad_low_int64, &pad_high_int64); if (!status.ok()) return false; @@ -314,7 +309,7 @@ class ConvertNdConvOp : public OpConversionPattern { return value; } - auto input_type = value.getType().cast(); + auto input_type = mlir::cast(value.getType()); auto input_shape = input_type.getShape(); llvm::SmallVector start; @@ -380,7 +375,7 @@ class ConvertNdConvOp : public OpConversionPattern { // Convolution. This is needed because TF.Conv3DOp doesn't support EXPLICIT. 
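// The explicit padding amounts are materialized as a constant and applied by
// a separate pad op, after which the convolution itself can use VALID padding.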
if (padding == "EXPLICIT" && num_spatial_dims == 3) { auto lhs_type = - conv_op.getLhs().getType().template dyn_cast(); + mlir::dyn_cast(conv_op.getLhs().getType()); RankedTensorType padding_attr_type = mlir::RankedTensorType::get( {lhs_type.getRank(), 2}, rewriter.getIntegerType(64)); auto padding_const = rewriter.create( @@ -394,7 +389,7 @@ class ConvertNdConvOp : public OpConversionPattern { padding = "VALID"; } - auto conv_output_type = conv_op.getType().cast(); + auto conv_output_type = mlir::cast(conv_op.getType()); DenseIntElementsAttr permutation; const bool need_transpose_output = NeedsReformatTypeAndPermutation( dnums.getOutputBatchDimension(), dnums.getOutputFeatureDimension(), @@ -418,7 +413,7 @@ class ConvertNdConvOp : public OpConversionPattern { // Reshapes filter format to [filter_height, filter_width, in_channels, // channel_multiplier] from HLO's [filter_height, filter_width, 1, // in_channels * channel_multiplier] format. - auto filter_type = rhs.getType().cast(); + auto filter_type = mlir::cast(rhs.getType()); llvm::ArrayRef hlo_filter_shape = filter_type.getShape(); llvm::SmallVector tf_filter_shape(hlo_filter_shape.begin(), hlo_filter_shape.end()); @@ -491,13 +486,13 @@ class Convert1DConvOp : public OpConversionPattern { // Group convolution is not supported yet. const int64_t input_feature_dimension = dnums.getInputFeatureDimension(); const int64_t input_channels = - conv_op.getLhs().getType().cast().getDimSize( - input_feature_dimension); + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); const int kernel_input_feature_dimension = dnums.getKernelInputFeatureDimension(); const int kernel_input_channels = - conv_op.getRhs().getType().cast().getDimSize( - kernel_input_feature_dimension); + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_input_feature_dimension); const int64_t feature_group_count = conv_op.getFeatureGroupCount(); if (feature_group_count != input_channels / kernel_input_channels || input_channels % kernel_input_channels != 0) @@ -508,7 +503,7 @@ class Convert1DConvOp : public OpConversionPattern { // // Reshape input image to add a new spatial dimension. - auto image_type = conv_op.getLhs().getType().cast(); + auto image_type = mlir::cast(conv_op.getLhs().getType()); SmallVector image_2d_shape(image_type.getShape().begin(), image_type.getShape().end()); image_2d_shape.push_back(1); @@ -530,7 +525,7 @@ class Convert1DConvOp : public OpConversionPattern { image_permutation_and_shape.permutation); // Reshape kernel to add a new spatial dimension. - auto kernel_type = conv_op.getRhs().getType().cast(); + auto kernel_type = mlir::cast(conv_op.getRhs().getType()); SmallVector kernel_2d_shape; for (int64_t dim : kernel_type.getShape()) { kernel_2d_shape.push_back(dim); @@ -623,7 +618,7 @@ class Convert1DConvOp : public OpConversionPattern { // // Determine the 2-D convolution output shape. 
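// As with the input image and kernel above, a trailing unit spatial dimension
// is appended so the 1-D result can be produced by a 2-D convolution.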
- auto output_type = conv_op->getResult(0).getType().cast(); + auto output_type = mlir::cast(conv_op->getResult(0).getType()); SmallVector output_2d_shape; for (int64_t dim : output_type.getShape()) { output_2d_shape.push_back(dim); @@ -648,7 +643,7 @@ class Convert1DConvOp : public OpConversionPattern { conv_op.getPrecisionConfigAttr()); OpResult conv2d_output = conv2d_op->getResult(0); - auto conv2d_output_type = conv2d_output.getType().cast(); + auto conv2d_output_type = mlir::cast(conv2d_output.getType()); // // Transpose and reshape the output @@ -676,9 +671,9 @@ using Convert3DConvOp = ConvertNdConvOp<3>; // lhs_dilation>1 and window_strides=1. LogicalResult IsSupportedNonTrivialConvOp(mhlo::ConvolutionOp conv_op, ConversionPatternRewriter& rewriter) { - if (!conv_op.getLhs().getType().cast().hasStaticShape() || - !conv_op.getRhs().getType().cast().hasStaticShape() || - !conv_op.getType().cast().hasStaticShape()) + if (!mlir::cast(conv_op.getLhs().getType()).hasStaticShape() || + !mlir::cast(conv_op.getRhs().getType()).hasStaticShape() || + !mlir::cast(conv_op.getType()).hasStaticShape()) return rewriter.notifyMatchFailure(conv_op, "requires static shape"); mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers(); @@ -687,10 +682,7 @@ LogicalResult IsSupportedNonTrivialConvOp(mhlo::ConvolutionOp conv_op, return rewriter.notifyMatchFailure(conv_op, "requires non-trivial lhs_dilation"); - if (conv_op.getWindowStrides() - .value() - .getType() - .cast() + if (mlir::cast(conv_op.getWindowStrides().value().getType()) .getRank() != 1) return rewriter.notifyMatchFailure( conv_op, "requires window_strides to equal to one"); @@ -746,19 +738,19 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers(); const int input_feature_dimension = dnums.getInputFeatureDimension(); const int input_channels = - conv_op.getLhs().getType().cast().getDimSize( - input_feature_dimension); + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); int feature_group_count = conv_op.getFeatureGroupCount(); const int kernel_input_feature_dimension = dnums.getKernelInputFeatureDimension(); const int kernel_input_channels = - conv_op.getRhs().getType().cast().getDimSize( - kernel_input_feature_dimension); + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_input_feature_dimension); const int kernel_output_feature_dimension = dnums.getKernelOutputFeatureDimension(); const int kernel_output_channels = - conv_op.getRhs().getType().cast().getDimSize( - kernel_output_feature_dimension); + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_output_feature_dimension); // To support a depthwise convolution, we need- // 1. 
feature_group_count != 1 (except when input_channels==1) @@ -795,7 +787,7 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp auto create_slice = [&](mlir::Value tensor, int depth_idx, int channel_idx, bool is_kernel = false) -> mlir::Value { std::vector tensor_shape = - tensor.getType().cast().getShape().vec(); + mlir::cast(tensor.getType()).getShape().vec(); // Calculate offsets based on depth_idx, channel_idx and tensor_shape std::vector start_indices(tensor_shape.size(), 0); @@ -828,7 +820,8 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp // Calculate convolution output_type based on sliced_input and // sliced_kernel - auto output_type = conv_op->getResult(0).getType().cast(); + auto output_type = + mlir::cast(conv_op->getResult(0).getType()); std::vector new_output_shape = output_type.getShape().vec(); new_output_shape[dnums.getOutputFeatureDimension()] /= feature_group_count; @@ -884,8 +877,8 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp int feature_group_count = conv_op.getFeatureGroupCount(); const int input_feature_dimension = dnums.getInputFeatureDimension(); const int input_channels = - conv_op.getLhs().getType().cast().getDimSize( - input_feature_dimension); + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); // Check for Group Convolution parameters if (feature_group_count != 1 && feature_group_count != input_channels) { @@ -919,7 +912,7 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp auto padding_values = padding.getValues(); // Cast the dimension sizes to int. - auto lhs_type = conv_op.getLhs().getType().cast(); + auto lhs_type = mlir::cast(conv_op.getLhs().getType()); llvm::SmallVector input_sizes = { static_cast(lhs_type.getDimSize(input_spatial_dimensions[0])), static_cast(lhs_type.getDimSize(input_spatial_dimensions[1]))}; @@ -1101,7 +1094,8 @@ class ConvertNonTrivialConvOp transpose_order[dnums.getOutputSpatialDimensions().data()[i]] = i + 1; } auto output_shape = - conv_op.getResult().getType().cast().getShape(); + mlir::cast(conv_op.getResult().getType()) + .getShape(); SmallVector transposed_output_shape = { output_shape[dnums.getOutputBatchDimension()], output_shape[dnums.getOutputSpatialDimensions().data()[0]], @@ -1114,7 +1108,7 @@ class ConvertNonTrivialConvOp } auto output_type = RankedTensorType::get( transposed_output_shape, - conv_op.getRhs().getType().cast().getElementType()); + mlir::cast(conv_op.getRhs().getType()).getElementType()); auto output_sizes = rewriter.create( conv_op.getLoc(), DenseIntElementsAttr::get( @@ -1138,7 +1132,8 @@ class ConvertNonTrivialConvOp } else { SmallVector output_shape_i32; for (int64_t dim : - conv_op.getResult().getType().cast().getShape()) { + mlir::cast(conv_op.getResult().getType()) + .getShape()) { output_shape_i32.push_back(dim); } auto output_sizes = rewriter.create( @@ -1176,14 +1171,12 @@ class ConvertNonTrivialConvOp for (size_t i = 1; i <= num_spatial_dims; ++i) { int64_t stride = strides[i]; - int64_t input_size = - conv_op.getLhs().getType().cast().getDimSize( - input_spatial_dims[i - 1]); - int64_t kernel_size = - conv_op.getRhs().getType().cast().getDimSize( - kernel_spatial_dims[i - 1]); - int64_t output_size = conv_op.getType().cast().getDimSize( - output_spatial_dims[i - 1]); + int64_t input_size = mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dims[i - 1]); + int64_t kernel_size = mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_spatial_dims[i - 1]); + int64_t output_size = 
mlir::cast(conv_op.getType()) + .getDimSize(output_spatial_dims[i - 1]); // stablehlo.convolution op needs explicit padding to be set to model any // Transposed-Convolution in JAX/PT. Checking to see if- @@ -1225,11 +1218,10 @@ class ConvertNonTrivialConvOp return false; } int64_t stride = strides[i + 1]; - int64_t input_size = - conv_op.getLhs().getType().cast().getDimSize( - input_spatial_dims[i]); - int64_t output_size = conv_op.getType().cast().getDimSize( - output_spatial_dims[i]); + int64_t input_size = mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dims[i]); + int64_t output_size = mlir::cast(conv_op.getType()) + .getDimSize(output_spatial_dims[i]); // The reason for the below check is as follows: // When computing the output, we have the following relation between // o - output dim size, i - input dim size, s - stride, P - total pads @@ -1280,13 +1272,11 @@ class ConvertDynamicSliceOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::DynamicSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - ShapedType input_type = op.getOperand().getType().cast(); + ShapedType input_type = mlir::cast(op.getOperand().getType()); if (!input_type.hasStaticShape()) return failure(); - Type start_indices_element_type = op.getStartIndices() - .front() - .getType() - .cast() - .getElementType(); + Type start_indices_element_type = + mlir::cast(op.getStartIndices().front().getType()) + .getElementType(); // The mhlo dynamic_slice's start_indices can be either signed/unsigned // int32/int64. However, TF only takes in either i32 or i64 types for begin, @@ -1307,8 +1297,8 @@ class ConvertDynamicSliceOp : public OpConversionPattern { for (uint64_t i = 0, e = op.getStartIndices().size(); i < e; ++i) { // Always put a cast there. 
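// Casting every start index, even one already of the target width, keeps the
// loop uniform; the clamp constructed below then bounds the index into range.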
auto start = op.getStartIndices()[i]; - auto cast_type = start.getType().cast().clone( - signed_start_indices_element_type); + auto cast_type = mlir::cast(start.getType()) + .clone(signed_start_indices_element_type); auto cast_op = rewriter.create(op.getLoc(), cast_type, start); Value clamp_max = rewriter.create( op.getLoc(), rewriter.getIntegerAttr( @@ -1409,11 +1399,11 @@ class ConvertDynamicUpdateSliceOp LogicalResult matchAndRewrite( mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - ShapedType operand_type = op.getOperand().getType().cast(); + ShapedType operand_type = mlir::cast(op.getOperand().getType()); ShapedType update_type = - op.getUpdate().getType().dyn_cast_or_null(); - ShapedType start_indices_type = - op.getStartIndices().front().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getUpdate().getType()); + ShapedType start_indices_type = mlir::dyn_cast_or_null( + op.getStartIndices().front().getType()); if (update_type == nullptr || start_indices_type == nullptr) return rewriter.notifyMatchFailure( op, "update and start_indices should have ShapedType"); @@ -1474,8 +1464,8 @@ class ConvertSortToTfTopk : public OpConversionPattern { op, "only match for the case where operands is of size 2"); auto keys = op.getInputs()[0]; auto indices = op.getInputs()[1]; - auto keys_ty = keys.getType().dyn_cast_or_null(); - auto indices_ty = indices.getType().dyn_cast_or_null(); + auto keys_ty = mlir::dyn_cast_or_null(keys.getType()); + auto indices_ty = mlir::dyn_cast_or_null(indices.getType()); if (!keys_ty || !keys_ty.hasStaticShape() || !keys_ty.getElementType().isIntOrFloat()) return rewriter.notifyMatchFailure( @@ -1589,7 +1579,7 @@ Value BuildDotOperandFlattenedShapeOp(Value operand, DotDimensionsInfo dot_dimensions_info, ImplicitLocOpBuilder& builder, bool is_lhs) { - auto operand_type = operand.getType().cast(); + auto operand_type = mlir::cast(operand.getType()); BoolAttr true_attr = builder.getBoolAttr(true); auto operand_shape = builder.create(operand, true_attr); const int64_t operand_rank = operand_type.getRank(); @@ -1665,8 +1655,8 @@ Value BuildDotOperandFlattenedShapeOp(Value operand, Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs, DotDimensionNumbersAttr dot_dimension_numbers, ShapedType result_type, mlir::Location loc) { - auto lhs_type = lhs.getType().cast(); - auto rhs_type = rhs.getType().cast(); + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); const int lhs_rank = lhs_type.getRank(); const int rhs_rank = rhs_type.getRank(); ImplicitLocOpBuilder builder(loc, rewriter); @@ -1821,7 +1811,7 @@ Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs, // necessary. Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_op = cast(old_op); - auto lhs_rank = dot_op.getLhs().getType().cast().getRank(); + auto lhs_rank = mlir::cast(dot_op.getLhs().getType()).getRank(); auto dot_dimension_numbers = DotDimensionNumbersAttr::get(rewriter.getContext(), /*lhs_batching_dimensions=*/{}, @@ -1831,17 +1821,18 @@ Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) { /*rhs_contracting_dimensions=*/{0}); return ConvertDot( rewriter, dot_op.getLhs(), dot_op.getRhs(), dot_dimension_numbers, - dot_op.getResult().getType().cast(), dot_op.getLoc()); + mlir::cast(dot_op.getResult().getType()), dot_op.getLoc()); } // Converts mhlo.dot to tf.BatchMatMul. 
Reshape or Transpose ops will also be // inserted to convert to well-formed matrix multiply. Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_general_op = cast(old_op); - return ConvertDot(rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(), - dot_general_op.getDotDimensionNumbers(), - dot_general_op.getResult().getType().cast(), - dot_general_op.getLoc()); + return ConvertDot( + rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(), + dot_general_op.getDotDimensionNumbers(), + mlir::cast(dot_general_op.getResult().getType()), + dot_general_op.getLoc()); } // Replace BinaryOp with a combination of TfBinaryOp and TfReduceOp if the @@ -1940,9 +1931,9 @@ class ConvertReduceOpToTfOp : public OpConversionPattern { reduce_op.getResults().size() != 1) return failure(); - if (!reduce_op.getInputs()[0].getType().isa()) + if (!mlir::isa(reduce_op.getInputs()[0].getType())) return failure(); - if (!reduce_op.getType(0).isa()) return failure(); + if (!mlir::isa(reduce_op.getType(0))) return failure(); return success(); } }; @@ -1953,13 +1944,13 @@ class ConvertReduceOpToTfProd using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp; LogicalResult MatchInitValue(Value init_value) const override { - auto type = init_value.getType().cast().getElementType(); - if (type.isa()) { + auto type = mlir::cast(init_value.getType()).getElementType(); + if (mlir::isa(type)) { float const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || const_value != 1.0) return failure(); - } else if (type.isa() && type.isSignlessInteger()) { + } else if (mlir::isa(type) && type.isSignlessInteger()) { int32_t const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || const_value != 1) @@ -1978,13 +1969,13 @@ class ConvertReduceOpToTfSum using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp; LogicalResult MatchInitValue(Value init_value) const override { - auto type = init_value.getType().cast().getElementType(); - if (type.isa()) { + auto type = mlir::cast(init_value.getType()).getElementType(); + if (mlir::isa(type)) { APFloat const_value(.0); if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isZero()) return failure(); - } else if (type.isa() && type.isSignlessInteger()) { + } else if (mlir::isa(type) && type.isSignlessInteger()) { APInt const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isZero()) @@ -2003,13 +1994,13 @@ class ConvertReduceOpToTfMax using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp; LogicalResult MatchInitValue(Value init_value) const override { - auto type = init_value.getType().cast().getElementType(); - if (type.isa()) { + auto type = mlir::cast(init_value.getType()).getElementType(); + if (mlir::isa(type)) { APFloat const_value(.0); if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isInfinity() || !const_value.isNegative()) return failure(); - } else if (type.isa() && type.isSignlessInteger()) { + } else if (mlir::isa(type) && type.isSignlessInteger()) { APInt const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isMinSignedValue()) @@ -2027,14 +2018,14 @@ class ConvertReduceOpToTfMin using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp; LogicalResult MatchInitValue(Value init_value) const override { - auto type = init_value.getType().cast().getElementType(); + auto type = mlir::cast(init_value.getType()).getElementType(); - if (type.isa()) { + if (mlir::isa(type)) { APFloat const_value(.0); if 
(failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isInfinity() || const_value.isNegative()) return failure(); - } else if (type.isa() && type.isSignlessInteger()) { + } else if (mlir::isa(type) && type.isSignlessInteger()) { APInt const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isMaxSignedValue()) @@ -2088,7 +2079,7 @@ class ConvertReduceOpToTfArgmax auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return value.isNegative() && value.isInfinity(); } else if (element_type.isInteger(1)) { @@ -2112,7 +2103,7 @@ class ConvertReduceOpToTfArgmin auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return !value.isNegative() && value.isInfinity(); } else if (element_type.isInteger(1)) { @@ -2134,18 +2125,18 @@ class ConvertIotaOpToTfRange : public OpConversionPattern { mhlo::IotaOp iota_op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { RankedTensorType type = - iota_op.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(iota_op.getType()); // TF::RangeOp doesn't support UI16. if (!type || type.getElementType().isUnsignedInteger(16)) return failure(); const uint64_t dimension = iota_op.getIotaDimension(); Type element_type = type.getElementType(); Attribute start, limit, delta; - if (element_type.isa()) { + if (mlir::isa(element_type)) { start = rewriter.getFloatAttr(element_type, 0.0); limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]); delta = rewriter.getFloatAttr(element_type, 1.0); - } else if (element_type.isa()) { + } else if (mlir::isa(element_type)) { start = rewriter.getIntegerAttr(element_type, 0); limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]); delta = rewriter.getIntegerAttr(element_type, 1); @@ -2249,9 +2240,10 @@ bool IsSpatialPoolingWithoutDilation( // Check that the individual padding values are corresponding to SAME // padding from TensorFlow. - auto operand_type = rw.getInputs()[0].getType().dyn_cast(); + auto operand_type = + mlir::dyn_cast(rw.getInputs()[0].getType()); RankedTensorType output_type = - rw.getResult(0).getType().dyn_cast(); + mlir::dyn_cast(rw.getResult(0).getType()); if (!operand_type || !output_type) return false; for (uint64_t i = 1; i < rank - 1; ++i) { @@ -2293,12 +2285,13 @@ class ConvertLoweredCumOp : public OpConversionPattern { auto const_op = llvm::dyn_cast_or_null( rw.getInitValues()[0].getDefiningOp()); if (!const_op) return failure(); - auto const_op_dense_value = const_op.getValue().cast(); + auto const_op_dense_value = + mlir::cast(const_op.getValue()); if (!const_op_dense_value || !IsInitValue(const_op_dense_value)) { return failure(); } - auto operand_type = rw.getInputs()[0].getType().cast(); + auto operand_type = mlir::cast(rw.getInputs()[0].getType()); // For a cumulative op, require a tensor of 1s for each dimension in // operand. 
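The recurring edit in these C++ hunks is mechanical: deprecated member-function casts on mlir::Type and mlir::Attribute are replaced by the free functions re-exported from mlir/Support/LLVM.h. A minimal before/after sketch, assuming a recent MLIR checkout; the helper IsStaticFloatTensor is hypothetical, and ShapedType/FloatType stand in for whichever template argument each call site actually uses:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

// Returns true if `v` is a statically shaped tensor with float elements.
static bool IsStaticFloatTensor(mlir::Value v) {
  // Deprecated member-function style:
  //   auto ty = v.getType().dyn_cast_or_null<mlir::ShapedType>();
  //   if (!ty || !ty.getElementType().isa<mlir::FloatType>()) return false;
  // Free-function style used throughout this patch:
  auto ty = mlir::dyn_cast_or_null<mlir::ShapedType>(v.getType());
  if (!ty || !mlir::isa<mlir::FloatType>(ty.getElementType())) return false;
  return ty.hasStaticShape();
}

The same one-to-one mapping covers cast, dyn_cast, and isa_and_nonnull; behavior is unchanged, only the call syntax moves.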
@@ -2383,7 +2376,7 @@ class ConvertLoweredCumSumOp auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return value.isZero(); } @@ -2399,7 +2392,7 @@ class ConvertLoweredCumProdOp auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return value.isExactlyValue(1.0); } @@ -2431,8 +2424,8 @@ class ConvertAvgPoolOp : public OpConversionPattern { // Check that this is a floating point reduce window with a rank of 4 or 5. const RankedTensorType rw_type = - rw.getResult(0).getType().dyn_cast(); - if (!rw_type || !rw_type.getElementType().isa() || + mlir::dyn_cast(rw.getResult(0).getType()); + if (!rw_type || !mlir::isa(rw_type.getElementType()) || rw_type.getRank() <= 3 || rw_type.getRank() > 5) return failure(); @@ -2568,8 +2561,8 @@ class ConvertMaxPoolOp : public OpConversionPattern { // Check that this is a floating point reduce window with a rank of 4 or 5. const RankedTensorType rw_type = - rw.getResult(0).getType().dyn_cast(); - if (!rw_type || !rw_type.getElementType().isa() || + mlir::dyn_cast(rw.getResult(0).getType()); + if (!rw_type || !mlir::isa(rw_type.getElementType()) || rw_type.getRank() <= 3 || rw_type.getRank() > 5) return failure(); @@ -2639,7 +2632,7 @@ class ConvertMaxPoolOp : public OpConversionPattern { // Returns the shape of the given value in a Constant Op. arith::ConstantOp ShapeToConst(PatternRewriter& rewriter, Value value) { - ArrayRef shape = value.getType().cast().getShape(); + ArrayRef shape = mlir::cast(value.getType()).getShape(); auto attr_type = RankedTensorType::get({static_cast(shape.size())}, rewriter.getIntegerType(64)); auto attr = DenseElementsAttr::get(attr_type, shape); @@ -2659,36 +2652,37 @@ bool IsSign(APFloat a, APFloat sign) { } bool IsDenseSplatIntAttr(ElementsAttr float_or_int) { - return float_or_int.isa() && - float_or_int.isa(); + return mlir::isa(float_or_int) && + mlir::isa(float_or_int); } bool IsDenseSplatFloatAttr(ElementsAttr float_or_int) { - return float_or_int.isa() && - float_or_int.isa(); + return mlir::isa(float_or_int) && + mlir::isa(float_or_int); } bool ValueIsReciprocal(ElementsAttr float_or_int, ElementsAttr rhs) { if (IsDenseSplatFloatAttr(float_or_int) && IsDenseSplatFloatAttr(float_or_int)) { - return (float_or_int.cast().getSplatValue() * - rhs.cast().getSplatValue()) + return (mlir::cast(float_or_int) + .getSplatValue() * + mlir::cast(rhs).getSplatValue()) .isExactlyValue(1.0); } else if (IsDenseSplatIntAttr(float_or_int) && IsDenseSplatIntAttr(float_or_int)) { - return (float_or_int.cast().getSplatValue() * - rhs.cast().getSplatValue()) == 1; + return (mlir::cast(float_or_int).getSplatValue() * + mlir::cast(rhs).getSplatValue()) == 1; } return false; } bool ValueEquals(ElementsAttr float_or_int, double rhs) { if (IsDenseSplatFloatAttr(float_or_int)) { - return float_or_int.cast() + return mlir::cast(float_or_int) .getSplatValue() .isExactlyValue(rhs); } else if (IsDenseSplatIntAttr(float_or_int)) { - return float_or_int.cast().getSplatValue() == + return mlir::cast(float_or_int).getSplatValue() == static_cast(rhs); } return false; @@ -2696,11 +2690,12 @@ bool ValueEquals(ElementsAttr float_or_int, double rhs) { bool ValueGreaterThanZero(ElementsAttr float_or_int) { if 
(IsDenseSplatIntAttr(float_or_int)) { - auto value = float_or_int.cast().getSplatValue(); + auto value = + mlir::cast(float_or_int).getSplatValue(); return !value.isNegative() && !value.isZero(); } else if (IsDenseSplatFloatAttr(float_or_int)) { auto value = - float_or_int.cast().getSplatValue(); + mlir::cast(float_or_int).getSplatValue(); return !value.isNaN() && !value.isNegative() && !value.isZero(); } return false; @@ -2723,13 +2718,13 @@ bool TensorIsSign(PatternRewriter& rewriter, ElementsAttr float_or_int, int_spl && sgn_cst_spl) { return IsSign(int_spl.getValue(), sgn_cst_spl.getValue()); } - if (float_or_int.isa()) { + if (mlir::isa(float_or_int)) { auto sgn_splat_value = sgn_splat.getSplatValue(); return llvm::all_of(float_or_int.getValues(), [&](APFloat value) { return IsSign(value, sgn_splat_value); }); } - if (float_or_int.isa()) { + if (mlir::isa(float_or_int)) { auto sgn_splat_value = sgn_splat.getSplatValue(); return llvm::all_of(float_or_int.getValues(), [&](APInt value) { return IsSign(value, sgn_splat_value); @@ -2778,9 +2773,11 @@ class ConvertGatherOp : public OpConversionPattern { Value start_indices = gather_op.getStartIndices(); // Can only convert with static shaped gather. - ShapedType operand_type = operand.getType().cast(); - ShapedType start_indices_type = start_indices.getType().cast(); - ShapedType result_type = gather_op.getResult().getType().cast(); + ShapedType operand_type = mlir::cast(operand.getType()); + ShapedType start_indices_type = + mlir::cast(start_indices.getType()); + ShapedType result_type = + mlir::cast(gather_op.getResult().getType()); if (!operand_type.hasStaticShape()) { gather_op.emitOpError() << "Dynamic shaped operand is not supported."; return failure(); @@ -2917,9 +2914,11 @@ class ConvertGatherOp : public OpConversionPattern { static const int max_batch_size = 50; // Can only convert with static shaped gather. - ShapedType operand_type = operand.getType().cast(); - ShapedType start_indices_type = start_indices.getType().cast(); - ShapedType result_type = gather_op.getResult().getType().cast(); + ShapedType operand_type = mlir::cast(operand.getType()); + ShapedType start_indices_type = + mlir::cast(start_indices.getType()); + ShapedType result_type = + mlir::cast(gather_op.getResult().getType()); if (!operand_type.hasStaticShape() || !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) { return rewriter.notifyMatchFailure( @@ -3140,7 +3139,7 @@ class ConvertWhileOp : public OpConversionPattern { // This rule doesn't support mhlo::WhileOp with tuple inputs. for (auto type : while_op->getOperandTypes()) { - if (type.isa()) return failure(); + if (mlir::isa(type)) return failure(); } // Creates a TF::WhileRegionOp to replace the mhlo::WhileOp. 
HLO WhileOp @@ -3296,7 +3295,7 @@ class ConvertCustomCallWithApproxTopK } } auto backend_config = - op.getBackendConfigAttr().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getBackendConfigAttr()); if (!backend_config) { return op.emitOpError() << "Missing backend_config attribute"; } @@ -3385,12 +3384,13 @@ class ConvertCustomCallWithApproxTopK << "ApproxTopK takes exactly 1 called_computation."; } mlir::func::FuncOp callee = module_op_->lookupSymbol( - op.getCalledComputations()[0].cast()); + mlir::cast(op.getCalledComputations()[0])); mlir::FunctionType callee_type = callee.getFunctionType(); SmallVector expected_callee_input_types; auto num_inputs = op.getInputs().size() / 2; for (unsigned i = 0; i < num_inputs; ++i) { - auto input_type = op.getOperand(i).getType().dyn_cast(); + auto input_type = + mlir::dyn_cast(op.getOperand(i).getType()); auto scalar = RankedTensorType::get({}, input_type.getElementType()); expected_callee_input_types.push_back(scalar); expected_callee_input_types.push_back(scalar); @@ -3491,12 +3491,10 @@ class ConvertRealDynamicSliceOp LogicalResult matchAndRewrite( mhlo::RealDynamicSliceOp real_dynamic_slice_op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - auto start_indices_type = real_dynamic_slice_op.getStartIndices() - .getType() - .cast(); - auto end_indices_type = real_dynamic_slice_op.getLimitIndices() - .getType() - .cast(); + auto start_indices_type = mlir::cast( + real_dynamic_slice_op.getStartIndices().getType()); + auto end_indices_type = mlir::cast( + real_dynamic_slice_op.getLimitIndices().getType()); if (start_indices_type.getNumDynamicDims() != 0 || end_indices_type.getNumDynamicDims() != 0) { @@ -3522,7 +3520,7 @@ class ConvertDynamicIotaOp : public OpConversionPattern { mhlo::DynamicIotaOp dynamic_iota_op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { RankedTensorType type = - dynamic_iota_op.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(dynamic_iota_op.getType()); if (!type || type.getElementType().isUnsignedInteger(64)) { return rewriter.notifyMatchFailure(dynamic_iota_op, "TF::RangeOp doesn't support UI64"); @@ -3538,19 +3536,19 @@ class ConvertDynamicIotaOp : public OpConversionPattern { const uint64_t dimension = dynamic_iota_op.getIotaDimension(); Type element_type = type.getElementType(); Attribute start, delta; - if (element_type.isa()) { + if (mlir::isa(element_type)) { start = rewriter.getFloatAttr(element_type, 0.0); delta = rewriter.getFloatAttr(element_type, 1.0); - } else if (element_type.isa()) { + } else if (mlir::isa(element_type)) { start = rewriter.getIntegerAttr(element_type, 0); delta = rewriter.getIntegerAttr(element_type, 1); } else { return failure(); } auto output_shape = dynamic_iota_op.getOperand(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto cast_type = - output_shape.getType().cast().clone(element_type); + mlir::cast(output_shape.getType()).clone(element_type); output_shape = rewriter.create(dynamic_iota_op.getLoc(), cast_type, output_shape); } @@ -3581,7 +3579,7 @@ bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions, // broadcast_dimensions is an increasing list by definition, thus it suffices // to check the first element. 
int64_t input_rank = broadcast_dimensions.getNumElements(); - int64_t output_rank = output.getType().cast().getRank(); + int64_t output_rank = mlir::cast(output.getType()).getRank(); return input_rank == 0 || (broadcast_dimensions.getValues()[0].getSExtValue() == output_rank - input_rank); @@ -3606,11 +3604,12 @@ arith::ConstantOp ExpandedShape(PatternRewriter& rewriter, Value input, Value output) { // Initialize expanded shape with output rank and dimensions of 1. SmallVector expanded_shape( - output.getType().cast().getRank(), + mlir::cast(output.getType()).getRank(), /*Value=*/rewriter.getI64IntegerAttr(1)); // Set dimension sizes specified by broadcast_dimensions. - ArrayRef input_shape = input.getType().cast().getShape(); + ArrayRef input_shape = + mlir::cast(input.getType()).getShape(); for (auto x : llvm::enumerate(broadcast_dimensions)) { expanded_shape[x.value().getSExtValue()] = rewriter.getI64IntegerAttr(input_shape[x.index()]); @@ -3627,9 +3626,9 @@ arith::ConstantOp ExpandedShape(PatternRewriter& rewriter, Value input, Value ExpandedDynamicShape(PatternRewriter& rewriter, Value input, DenseIntElementsAttr broadcast_dimensions, Value output) { - assert(output.getType().cast() && + assert(mlir::cast(output.getType()) && "output type must be of ShapedType"); - int64_t output_rank = output.getType().cast().getRank(); + int64_t output_rank = mlir::cast(output.getType()).getRank(); llvm::SmallVector expanded_dimensions; llvm::SmallSet broadcast_dimensions_values; for (auto x : llvm::enumerate(broadcast_dimensions)) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc index 9d52ee30dd3ce7..520cff8681156a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc @@ -37,7 +37,7 @@ LogicalResult ConvertCustomCallOp::matchAndRewrite( rewriter.getStringAttr(mhlo_custom_call.getCallTargetName())); if (auto bc = mhlo_custom_call.getBackendConfig()) { - if (auto stringattr = bc->dyn_cast_or_null()) { + if (auto stringattr = mlir::dyn_cast_or_null(*bc)) { tfl_custom.setCustomOptionAttr( TFL::ConstBytesAttr::get(rewriter.getContext(), stringattr)); } @@ -53,7 +53,7 @@ LogicalResult ConvertCustomCallOp::matchAndRewrite( std::optional IsCustomCallLegal(mhlo::CustomCallOp op) { if (op.getCallTargetName().starts_with("custom_call.")) { auto bc = op.getBackendConfig(); - if (!bc || bc->isa()) { + if (!bc || mlir::isa(*bc)) { return false; } } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc index ef3337cbca27cd..ccd726d2737f84 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc @@ -169,7 +169,7 @@ Value BuildDotOperandFlattenedShapeOp(Value operand, DotDimensionsInfo dot_dimensions_info, ImplicitLocOpBuilder& builder, bool is_lhs) { - auto operand_type = operand.getType().cast(); + auto operand_type = mlir::cast(operand.getType()); auto operand_shape = builder.create( RankedTensorType::get(static_cast(operand_type.getRank()), builder.getIntegerType(32)), @@ -248,8 +248,8 @@ Value BuildDotOperandFlattenedShapeOp(Value 
operand, Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs, mhlo::DotDimensionNumbersAttr dot_dimension_numbers, ShapedType result_type, mlir::Location loc) { - auto lhs_type = lhs.getType().cast(); - auto rhs_type = rhs.getType().cast(); + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); const int lhs_rank = lhs_type.getRank(); const int rhs_rank = rhs_type.getRank(); ImplicitLocOpBuilder builder(loc, rewriter); @@ -412,7 +412,7 @@ Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs, // be inserted when necessary. See ConvertDotGeneralOp for additional notes. Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_op = cast(old_op); - auto lhs_rank = dot_op.getLhs().getType().cast().getRank(); + auto lhs_rank = mlir::cast(dot_op.getLhs().getType()).getRank(); auto dot_dimension_numbers = mhlo::DotDimensionNumbersAttr::get(rewriter.getContext(), /*lhsBatchingDimensions=*/{}, @@ -422,15 +422,16 @@ Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) { /*rhsContractingDimensions=*/{0}); return ConvertDot( rewriter, dot_op.getLhs(), dot_op.getRhs(), dot_dimension_numbers, - dot_op.getResult().getType().cast(), dot_op.getLoc()); + mlir::cast(dot_op.getResult().getType()), dot_op.getLoc()); } Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_general_op = cast(old_op); - return ConvertDot(rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(), - dot_general_op.getDotDimensionNumbers(), - dot_general_op.getResult().getType().cast(), - dot_general_op.getLoc()); + return ConvertDot( + rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(), + dot_general_op.getDotDimensionNumbers(), + mlir::cast(dot_general_op.getResult().getType()), + dot_general_op.getLoc()); } } // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h index bfb705d00813d5..157cb82ce8e94e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h @@ -74,7 +74,7 @@ class ConvertReduceOpToArgMinMax : public OpConversionPattern { if (!MatchIota(reduce_op.getDimensions(), iota)) return failure(); // Match the reduction computation. - const bool is_float = operand_init.getElementType().isa(); + const bool is_float = mlir::isa(operand_init.getElementType()); if (failed(MatchReduceToArgMinMaxType1(reduce_op, is_float, is_argmax)) && failed(MatchReduceToArgMinMaxType2(reduce_op, is_argmax))) return rewriter.notifyMatchFailure( @@ -91,8 +91,8 @@ class ConvertReduceOpToArgMinMax : public OpConversionPattern { // Generate a Max and an ArgMax of as the mhlo op returns both while in TF // we have separate ops for them. If only one of them is used then the other // one will be garbage collected later. - if (!operand.getType().isa()) return failure(); - auto operand_type = operand.getType().cast(); + if (!mlir::isa(operand.getType())) return failure(); + auto operand_type = mlir::cast(operand.getType()); if (operand_type.getElementType().isInteger(1)) { // TF does not support min or max on boolean (int1) arguments. // Use AnyOp for MaxOp and AllOp for MinOp. 
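The hunks above and below all apply the same mechanical migration: LLVM and MLIR have deprecated the member-function casting API (x.isa<T>(), x.cast<T>(), x.dyn_cast<T>(), x.dyn_cast_or_null<T>()) in favor of the free functions that mlir/Support/LLVM.h re-exports from LLVM, which is why several files in this change also gain that include. A minimal self-contained sketch of the before/after pattern (illustrative placeholder code, not part of the patch):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

// Free-function casts take the value as an argument instead of being invoked
// on it; behavior is otherwise unchanged.
void CastMigrationExample(mlir::Type t) {
  // Before: if (t.isa<mlir::FloatType>()) { auto f = t.cast<mlir::FloatType>(); ... }
  if (mlir::isa<mlir::FloatType>(t)) {
    auto f = mlir::cast<mlir::FloatType>(t);  // asserts if the cast fails
    (void)f;
  }
  // dyn_cast returns a null result on mismatch; dyn_cast_or_null additionally
  // tolerates a null input, which is why the patch uses it for optional
  // attributes such as backend_config.
  if (auto ranked = mlir::dyn_cast<mlir::RankedTensorType>(t)) {
    (void)ranked;
  }
}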
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h index fb0e0d80a4eb9b..f8ea8227137617 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h @@ -95,9 +95,9 @@ class ConvertScatterOp : public OpConversionPattern { OperandRange updates = scatter_op.getUpdates(); if (operands.size() != 1 || updates.size() != 1) return failure(); - ShapedType operand_type = operands[0].getType().cast(); - ShapedType indices_type = indices.getType().cast(); - ShapedType updates_type = updates[0].getType().cast(); + ShapedType operand_type = mlir::cast(operands[0].getType()); + ShapedType indices_type = mlir::cast(indices.getType()); + ShapedType updates_type = mlir::cast(updates[0].getType()); Value new_updates = updates[0]; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc index c2f533776d0408..783f0431e9b964 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc @@ -203,7 +203,7 @@ Value InsertTranspose(Value value, int batch_dim, int feature_dim, int default_batch_dim, int default_feature_dim, int default_spatial_dim_start, int num_spatial_dims, ConversionPatternRewriter& rewriter) { - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); DenseIntElementsAttr permutation; const int spatial_dim_start = spatial_dimensions.front(); if (!NeedsReformatTypeAndPermutation( @@ -224,7 +224,7 @@ Value InsertTranspose(Value value, int batch_dim, int feature_dim, Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { IntegerType new_ele_type = rewriter.getIntegerType(32); - if (auto shaped_type = val.getType().dyn_cast()) { + if (auto shaped_type = mlir::dyn_cast(val.getType())) { ShapedType new_type = RankedTensorType::get(shaped_type.getShape(), new_ele_type); return rewriter.create(loc, new_type, val); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc index 6e0a3325460b7a..50521a02c7b907 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc @@ -63,13 +63,13 @@ LogicalResult BuildOption(flexbuffers::Builder* fbb, Operation* op, const char* key = pair.getName().data(); const auto attr = pair.getValue(); - if (attr.isa<::mlir::IntegerAttr>()) { - fbb->Int(key, attr.dyn_cast().getInt()); + if (mlir::isa<::mlir::IntegerAttr>(attr)) { + fbb->Int(key, mlir::dyn_cast(attr).getInt()); return success(); } - if (attr.isa<::mlir::FloatAttr>()) { - fbb->Double(key, attr.dyn_cast().getValueAsDouble()); + if (mlir::isa<::mlir::FloatAttr>(attr)) { + fbb->Double(key, mlir::dyn_cast(attr).getValueAsDouble()); return success(); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc index 
4cfb0e04e96af4..e699c303bbaac2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc @@ -45,18 +45,19 @@ struct ReplaceCustomCallWithComposite final LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op, PatternRewriter &rewriter) const override { auto backendConfig = - op->getAttr("composite.backend_config").dyn_cast(); + mlir::dyn_cast(op->getAttr("composite.backend_config")); if (!backendConfig) return op->emitError( "custom_call has no 'composite.backend_config' attribute or the " "attribute is not a dictionary"); - auto name = backendConfig.get("name").dyn_cast(); + auto name = mlir::dyn_cast(backendConfig.get("name")); if (!name) return op->emitError( "backend_config has no 'name' key or the name value is not a string"); - auto attrs = backendConfig.get("attributes").dyn_cast(); + auto attrs = + mlir::dyn_cast(backendConfig.get("attributes")); if (!attrs) return op->emitError( "backend_config has no 'attributes' key or the attributes value is " @@ -66,7 +67,7 @@ struct ReplaceCustomCallWithComposite final if (!calledComputations || calledComputations.size() != 1) return op->emitError("expected exactly one called_computation"); - auto decomposition = calledComputations[0].cast(); + auto decomposition = mlir::cast(calledComputations[0]); auto composite = rewriter.create( op.getLoc(), op.getResultTypes(), op.getOperands(), name.str(), attrs, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc index cd6199e3c152f7..b3a85259cc482a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/lite/core/macros.h" #define DEBUG_TYPE "compat-passes" @@ -92,7 +93,7 @@ class StablehloToOdmlTypeConverter : public vhlo::VhloTypeConverter { return attr; if (auto stablehlo_attr = - attr.dyn_cast_or_null()) { + mlir::dyn_cast_or_null(attr)) { return vhlo::TypeExtensionsV1Attr::get(stablehlo_attr.getContext(), stablehlo_attr.getBounds()); } @@ -118,7 +119,8 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { } Attribute convertEncoding(Attribute attr) const final { - if (auto vhlo_attr = attr.dyn_cast_or_null()) { + if (auto vhlo_attr = + mlir::dyn_cast_or_null(attr)) { return stablehlo::TypeExtensionsAttr::get(vhlo_attr.getContext(), vhlo_attr.getBounds()); } @@ -230,7 +232,7 @@ LogicalResult ApplyVhloToVersionPatterns(ModuleOp module, PassManager pm(module.getContext()); pm.addPass(stablehlo::createVhloToVersionPass({version})); if (failed(pm.run(module))) { - return module->emitError("Failed VHLO to version") << version; + return module->emitError("Failed VHLO to version ") << version; } return success(); } @@ -274,11 +276,11 @@ struct LegalizeStablehloToVhloPass LegalizeStablehloToVhloPass> { void runOnOperation() override { ModuleOp module = getOperation(); - std::string target_version = "0.14.0"; + std::string target_version = tflite_supported_stablehlo_version; VhloToStablehloTypeConverter to_builtin_converter; // StableHLO --> VHLO (allow funcs) - // VHLO -> Downgrade to 0.14.0 + // VHLO -> Downgrade to tflite_supported_stablehlo_version // VHLO Tensor --> Builtin Tensor // Remove cast(tensor->vhlo) -> cast(vhlo->tensor) pattern if (failed(ApplyStablehloToVhloPatterns(module, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc index f7a136f2259ad2..82c7a4b4687055 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc @@ -105,8 +105,8 @@ void PrintOpStatsPass::runOnOperation() { (dyn_cast_or_null(op)) ? op->getOperand(1) : op->getResult(0); - ShapedType value_shaped_type = - value_for_deducing_op_type.getType().dyn_cast_or_null(); + ShapedType value_shaped_type = mlir::dyn_cast_or_null( + value_for_deducing_op_type.getType()); if (value_shaped_type != nullptr) { auto operand_or_result = value_shaped_type.getElementType(); std::string dtype; @@ -122,15 +122,16 @@ void PrintOpStatsPass::runOnOperation() { }) .Case([&](Type) { auto uniform_quantized_dtype = - operand_or_result.dyn_cast_or_null() + mlir::dyn_cast_or_null( + operand_or_result) .getStorageType(); dtype = absl::StrCat( "uq_", uniform_quantized_dtype.getIntOrFloatBitWidth()); }) .Case([&](Type) { auto uniform_quantized_dtype = - operand_or_result - .dyn_cast_or_null() + mlir::dyn_cast_or_null( + operand_or_result) .getStorageType(); dtype = absl::StrCat( "uq_", uniform_quantized_dtype.getIntOrFloatBitWidth()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc index b0797521798994..d9c23dfa12b8ae 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -37,8 +38,8 @@ namespace odml { // Convert mhlo.dot to mhlo.dot_general. LogicalResult ConvertDotToDotGeneral(mhlo::DotOp op, PatternRewriter &rewriter) { - auto lhs_type = op.getLhs().getType().cast(); - auto rhs_type = op.getRhs().getType().cast(); + auto lhs_type = mlir::cast(op.getLhs().getType()); + auto rhs_type = mlir::cast(op.getRhs().getType()); if (!lhs_type.hasRank() || !rhs_type.hasRank()) { return rewriter.notifyMatchFailure(op, "unsupported unranked input type"); } @@ -264,7 +265,7 @@ LogicalResult LiftDotConcatLHS(mhlo::ConcatenateOp concat, new_concat_shape[new_concat_dim] = 0; for (auto v : all_dot_lhs) { new_concat_shape[new_concat_dim] += - v.getType().dyn_cast().getShape()[new_concat_dim]; + mlir::dyn_cast(v.getType()).getShape()[new_concat_dim]; } auto new_concat = rewriter.create( @@ -353,7 +354,7 @@ LogicalResult LiftDotConcatLHSAndRHS(mhlo::ConcatenateOp concat, lhs_new_concat_shape[lhs_batch_dim] = 0; for (auto v : all_dot_lhs) { lhs_new_concat_shape[lhs_batch_dim] += - v.getType().dyn_cast().getShape()[lhs_batch_dim]; + mlir::dyn_cast(v.getType()).getShape()[lhs_batch_dim]; } const int64_t rhs_batch_dim = first_dot.getDotDimensionNumbers().getRhsBatchingDimensions()[0]; @@ -362,7 +363,7 @@ LogicalResult LiftDotConcatLHSAndRHS(mhlo::ConcatenateOp concat, rhs_new_concat_shape[rhs_batch_dim] = 0; for (auto v : all_dot_rhs) { rhs_new_concat_shape[rhs_batch_dim] += - v.getType().dyn_cast().getShape()[rhs_batch_dim]; + mlir::dyn_cast(v.getType()).getShape()[rhs_batch_dim]; } auto lhs_new_concat = rewriter.create( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc new file mode 100644 index 00000000000000..11cb7254b75c76 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc @@ -0,0 +1,210 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for legalizing HLO to TensorFlow. 
+ +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Utils/IndexingUtils.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" + +namespace mlir { +namespace odml { +namespace { + +#define DEBUG_TYPE "stablehlo-optimize-layout" + +#define GEN_PASS_DEF_TRANSPOSECOMMUTEOPSPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" + +class TransposeCommuteOpsPass + : public impl::TransposeCommuteOpsPassBase<TransposeCommuteOpsPass> { + void runOnOperation() override; +}; + +// Inversely permute a given vector +static SmallVector<int64_t> InvertPermutationToVector(ArrayRef<int64_t> vec, + ArrayRef<int64_t> perm) { + return applyPermutation(vec, invertPermutationVector(perm)); +} + +static RankedTensorType GetPermutedTensorTypeHelper(RankedTensorType type, + ArrayRef<int64_t> perm, + bool isInvert) { + SmallVector<int64_t> permutedShape = applyPermutation( + type.getShape(), isInvert ?
invertPermutationVector(perm) : perm); + return RankedTensorType::get(permutedShape, type.getElementType()); +} + +static RankedTensorType GetInvertPermutedTensorType(RankedTensorType type, + ArrayRef<int64_t> perm) { + return GetPermutedTensorTypeHelper(type, perm, true /*isInvert*/); +} + +static Value CreateTranspose(OpBuilder& builder, Value source, + ArrayRef<int64_t> perm) { + return builder.create<stablehlo::TransposeOp>(source.getLoc(), source, perm) + ->getResult(0); +} + +// Transform pad(transpose(x)) to transpose(pad(x)) +struct TransposeCommuteWithPad : public OpRewritePattern<stablehlo::PadOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::PadOp pad_op, + PatternRewriter& rewriter) const override { + Value pad_input = pad_op.getOperand(); + RankedTensorType pad_type = pad_op.getType().cast<RankedTensorType>(); + + auto transpose_op = pad_input.getDefiningOp<stablehlo::TransposeOp>(); + if (!transpose_op || !transpose_op->hasOneUse()) return failure(); + Value transpose_input = transpose_op.getOperand(); + + ArrayRef<int64_t> transpose_perm = transpose_op.getPermutation(); + SmallVector<int64_t> new_padding_low = + InvertPermutationToVector(pad_op.getEdgePaddingLow(), transpose_perm); + SmallVector<int64_t> new_padding_high = + InvertPermutationToVector(pad_op.getEdgePaddingHigh(), transpose_perm); + SmallVector<int64_t> new_padding_interior = + InvertPermutationToVector(pad_op.getInteriorPadding(), transpose_perm); + + RankedTensorType new_pad_type = + GetInvertPermutedTensorType(pad_type, transpose_perm); + Value new_pad = rewriter.create<stablehlo::PadOp>( + pad_op.getLoc(), new_pad_type, transpose_input, + pad_op.getPaddingValue(), new_padding_low, new_padding_high, + new_padding_interior); + + Value orig_pad = CreateTranspose(rewriter, new_pad, transpose_perm); + rewriter.replaceOp(pad_op, orig_pad); + return success(); + } +}; + +// Transform reduce_window(transpose(x)) to transpose(reduce_window(x)) +struct TransposeCommuteWithReduceWindow + : public OpRewritePattern<stablehlo::ReduceWindowOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::ReduceWindowOp reduce_op, + PatternRewriter& rewriter) const override { + MLIRContext* ctx = reduce_op.getContext(); + ValueRange inputs = reduce_op.getInputs(); + // Only handle binary reduce ops for now + if (inputs.size() != 1) return failure(); + Value reduce_input = inputs[0]; + + RankedTensorType reduce_type = + reduce_op.getResultTypes()[0].cast<RankedTensorType>(); + + auto transpose_op = reduce_input.getDefiningOp<stablehlo::TransposeOp>(); + if (!transpose_op || !transpose_op->hasOneUse()) return failure(); + Value transpose_input = transpose_op.getOperand(); + + ArrayRef<int64_t> transpose_perm = transpose_op.getPermutation(); + + // Inversely transposes all the attributes to prepare for the new reduce op + auto getInvertPermutedAttr = + [&](std::optional<ArrayRef<int64_t>> vals) -> DenseI64ArrayAttr { + return vals.has_value() + ?
DenseI64ArrayAttr::get( + ctx, InvertPermutationToVector(*vals, transpose_perm)) + : nullptr; + }; + DenseI64ArrayAttr new_window_dimensions = + getInvertPermutedAttr(reduce_op.getWindowDimensions()); + DenseI64ArrayAttr new_window_strides = + getInvertPermutedAttr(reduce_op.getWindowStrides()); + DenseI64ArrayAttr new_base_dilations = + getInvertPermutedAttr(reduce_op.getBaseDilations()); + DenseI64ArrayAttr new_win_dilations = + getInvertPermutedAttr(reduce_op.getWindowDilations()); + + auto padding = reduce_op.getPadding(); + int64_t rank = transpose_perm.size(); + DenseIntElementsAttr new_padding_attr = nullptr; + if (padding.has_value()) { + SmallVector<int64_t> new_padding(rank * 2, 0); + auto old_padding = (*padding).getValues<int64_t>(); + for (int64_t idx = 0; idx < rank; ++idx) { + new_padding[2 * transpose_perm[idx]] = old_padding[2 * idx]; + new_padding[2 * transpose_perm[idx] + 1] = old_padding[2 * idx + 1]; + } + new_padding_attr = + DenseIntElementsAttr::get((*padding).getType(), new_padding); + } + + RankedTensorType new_reduce_type = + GetInvertPermutedTensorType(reduce_type, transpose_perm); + auto new_reduce_op = rewriter.create<stablehlo::ReduceWindowOp>( + reduce_op.getLoc(), new_reduce_type, transpose_input, + reduce_op.getInitValues()[0], new_window_dimensions, new_window_strides, + new_base_dilations, new_win_dilations, new_padding_attr); + IRMapping mapping; + reduce_op.getBody().cloneInto(&new_reduce_op.getBody(), mapping); + + Value orig_reduce_op = + CreateTranspose(rewriter, new_reduce_op->getResult(0), transpose_perm); + rewriter.replaceOp(reduce_op, orig_reduce_op); + return success(); + } +}; + +void TransposeCommuteOpsPass::runOnOperation() { + auto* ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add<TransposeCommuteWithPad, TransposeCommuteWithReduceWindow>(ctx); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // end namespace + +std::unique_ptr<OperationPass<ModuleOp>> CreateTransposeCommuteOpsPass() { + return std::make_unique<TransposeCommuteOpsPass>(); +} + +static PassRegistration<TransposeCommuteOpsPass> pass; + +} // end namespace odml +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h index 49e8b673f63374..6c8e587871d393 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h @@ -49,6 +49,9 @@ CreateComposeUniformQuantizedTypePass(); std::unique_ptr> CreateUniformQuantizedStableHloToTflPass(); +// Create a pass that commutes transposes through specific ops +std::unique_ptr> CreateTransposeCommuteOpsPass(); + // Create a pass that legalizes MHLO to TF dialect.
std::unique_ptr> CreateLegalizeHloToTfPass(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td index a535d3aa867c80..1a1d7335b0517e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td @@ -114,3 +114,9 @@ def CompositeLoweringPass : Pass<"composite-lowering", "ModuleOp"> { let dependentDialects = ["mlir::mhlo::MhloDialect", "TFL::TensorFlowLiteDialect"]; let constructor = "mlir::odml::CreateCompositeLoweringPass()"; } + +def TransposeCommuteOpsPass : Pass<"transpose-commute-ops", "ModuleOp"> { + let summary = "Move transpose through specific ops"; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; + let constructor = "mlir::odml::CreateTransposeCommuteOpsPass()"; +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc index 72ae7cc1c0047d..81c6fc47473d43 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc @@ -18,10 +18,12 @@ limitations under the License. #include #include +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" namespace mlir { @@ -61,21 +63,40 @@ class RenameEntrypointToMainPass // } // clang-format on for (auto attr : session_initializer.getInitializers()) { - auto sym_attr = attr.dyn_cast(); + auto sym_attr = mlir::dyn_cast(attr); if (!sym_attr) break; entrypoints.erase(sym_attr.getValue()); } } if (entrypoints.empty()) { - fail(module, "No entrypoints found"); - } else if (entrypoints.size() == 1) { + return fail(module, "No entrypoints found"); + } + if (entrypoints.size() == 1) { auto entrypoint = entrypoints.begin()->second; Builder builder(entrypoint); entrypoint.setName(builder.getStringAttr("main")); - } else { - fail(module, "Too many entrypoints found"); + return; + } + + // If there is more than one entry point, choose the one with the + // 'tf.entry_function' attribute set.
+ llvm::SmallVector candidate_funcs; + for (auto& entrypoint : entrypoints) { + if (entrypoint.second->hasAttr("tf.entry_function")) { + candidate_funcs.push_back(entrypoint.second); + } + } + + if (candidate_funcs.empty()) { + return fail(module, "No entrypoints found"); + } + if (candidate_funcs.size() > 1) { + return fail(module, "Too many entrypoints found"); } + // Found entrypoint + Builder builder(candidate_funcs[0]); + candidate_funcs[0].setName(builder.getStringAttr("main")); } }; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc index 4304d34f4743ec..f86b78275fb951 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc @@ -66,7 +66,7 @@ class ConvertReduceOpToTFLiteArgmax auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return value.isNegative() && value.isInfinity(); } else if (element_type.isInteger(1)) { @@ -90,7 +90,7 @@ class ConvertReduceOpToTFLiteArgmin auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return !value.isNegative() && value.isInfinity(); } else if (element_type.isInteger(1)) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc index 7eb3abdef793eb..fdbf12538f230e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc @@ -77,7 +77,7 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, // TF -> StableHLO legalization. AddLegalizeTFToStablehloPasses(pm, /*skip_quantization_ops=*/false, skip_resize, - /*skip_stateful_partitioned_call=*/false); + /*skip_partitioned_calls=*/false); // Wrap disallowed ops in stablehlo.custom_call ops. if (smuggle_disallowed_ops) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc index e7f86a022d2274..7a3abd35d0d376 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { @@ -60,7 +61,7 @@ class UnfoldSplatConstantPass void UnfoldSplatConstant(mlir::OpBuilder* op_builder, mhlo::ConstantOp const_op) const { auto splat_elements_attr = - const_op.getValue().dyn_cast(); + mlir::dyn_cast(const_op.getValue()); if (!splat_elements_attr) { return; } @@ -68,8 +69,8 @@ class UnfoldSplatConstantPass return; } auto element_type = splat_elements_attr.getType().getElementType(); - if (element_type.isa() || - element_type.isa()) { + if (mlir::isa(element_type) || + mlir::isa(element_type)) { return; } op_builder->setInsertionPoint(const_op); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc index f4cdad00b79774..dadcabc55a5e57 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" @@ -60,7 +61,8 @@ Value broadcastToFeatureDim(Location loc, RankedTensorType result_type, // Gets the shape of operand, assuming it is a dynamic shape with static rank. Value getShapeValue(Location loc, Value operand, PatternRewriter &rewriter) { - RankedTensorType resultType = operand.getType().dyn_cast(); + RankedTensorType resultType = + mlir::dyn_cast(operand.getType()); return rewriter.create( loc, RankedTensorType::get(/*shape=*/{resultType.getRank()}, @@ -92,8 +94,8 @@ Value materializeEpsilon(Operation *op, FloatAttr epsilon_attr, } auto scalar_type = RankedTensorType::get(/*shape=*/{}, fp_type); - auto epsilon_tensor_attr = - DenseElementsAttr::get(scalar_type, {epsilon_attr.cast()}); + auto epsilon_tensor_attr = DenseElementsAttr::get( + scalar_type, {mlir::cast(epsilon_attr)}); Value epsilon = b.create(epsilon_tensor_attr); auto dims_type = RankedTensorType::get(/*shape=*/{0}, b.getIntegerType(64)); auto dims = DenseIntElementsAttr::get(dims_type, SmallVector{}); @@ -113,7 +115,7 @@ class UnfuseBatchNormTrainingPattern LogicalResult matchAndRewrite(mhlo::BatchNormTrainingOp bn_op, PatternRewriter &rewriter) const override { auto inputs = bn_op.getOperand(); - auto input_type = inputs.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(inputs.getType()); if (!input_type) { return failure(); } @@ -172,13 +174,14 @@ class UnfuseBatchNormInferencePattern // Enforce type invariants. // Note that we deduce the actual element type from the variance, // which should not be subject to quantization at a higher level. 
- auto input_type = bn_op.getOperand().getType().dyn_cast(); + auto input_type = + mlir::dyn_cast(bn_op.getOperand().getType()); auto variance_type = - bn_op.getVariance().getType().dyn_cast(); + mlir::dyn_cast(bn_op.getVariance().getType()); if (!input_type || !variance_type) { return failure(); } - auto fp_type = variance_type.getElementType().dyn_cast(); + auto fp_type = mlir::dyn_cast(variance_type.getElementType()); if (!fp_type) { return failure(); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index ad3bc3cd4cd24d..c3a05d5a0706a7 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -310,6 +310,137 @@ Type GetQuantizedOutputType(Operation* op, PatternRewriter& rewriter, new_result_quantized_type); } +// Matches the kernel dimension numbers, the input and output ranks, and the +// constant kernel required for legalization to TFLite convolution ops. +LogicalResult MatchConvolutionFormat(stablehlo::ConvolutionOp op) { + stablehlo::ConvDimensionNumbersAttr dimension_numbers = + op.getDimensionNumbers(); + const int64_t kernel_input_feature_dim = + dimension_numbers.getKernelInputFeatureDimension(); + if (kernel_input_feature_dim != 2) { + LLVM_DEBUG(llvm::dbgs() << "Expected kernel input feature == 2. Got: " + << kernel_input_feature_dim << ".\n"); + return failure(); + } + + const int64_t kernel_output_feature_dim = + dimension_numbers.getKernelOutputFeatureDimension(); + if (kernel_output_feature_dim != 3) { + LLVM_DEBUG(llvm::dbgs() << "Expected kernel output feature == 3. Got: " + << kernel_output_feature_dim << ".\n"); + return failure(); + } + + const auto input_type = op.getLhs().getType().cast(); + if (input_type.getRank() != 4) { + LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " + "Expected input rank of 4. Got: " + << input_type.getRank() << ".\n"); + return failure(); + } + + const auto filter_type = op.getRhs().getType().cast(); + if (filter_type.getRank() != 4) { + LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " + "Expected filter rank of 4. Got: " + << filter_type.getRank() << ".\n"); + return failure(); + } + + if (Operation* filter_op = op.getRhs().getDefiningOp(); + filter_op == nullptr || !isa(filter_op)) { + LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); + return failure(); + } + + return success(); +} + +// Transposes the convolution filter tensor of format [0, 1, i, o] to match the +// filter tensor format for TFLite convolution. The following transformations +// are supported: +// +// Depthwise case (`feature_group_count` > 1) +// * Permutes the given filter to `[i, 0, 1, o]` format. +// General convolution (`feature_group_count` = 1) +// * Permutes the given filter to `[o, 0, 1, i]` format. +// Using TransposeOp doesn't work because the quantized dimension +// changes, which violates the constraint for the TransposeOp that the +// input's and output's element type should be the same.
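// Editorial example (not part of the patch): for a filter of shape
// [0, 1, i, o] = [2, 2, 3, 8], the depthwise permutation {2, 0, 1, 3} below
// produces shape [3, 2, 2, 8] ([i, 0, 1, o], as `tfl.depthwise_conv_2d`
// expects), while the general permutation {3, 0, 1, 2} produces
// [8, 2, 2, 3] ([o, 0, 1, i], as `tfl.conv_2d` expects).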
+DenseIntElementsAttr TransposeFilterInConvolution( + Location loc, PatternRewriter& rewriter, + const DenseIntElementsAttr& filter_value_attr, const bool is_depthwise) { + ArrayRef filter_shape = filter_value_attr.getShapedType().getShape(); + SmallVector filter_constant_values{ + filter_value_attr.getValues()}; + SmallVector new_filter_constant_values(filter_constant_values.size(), + 0); + SmallVector transpose_dims; + if (is_depthwise) { + transpose_dims = {2, 0, 1, 3}; + } else { + transpose_dims = {3, 0, 1, 2}; + } + + SmallVector new_filter_shape; + new_filter_shape.reserve(filter_shape.size()); + for (int i = 0; i < filter_shape.size(); ++i) { + new_filter_shape.push_back(filter_shape[transpose_dims[i]]); + } + + auto get_array_idx = [](ArrayRef shape, const int i, const int j, + const int k, const int l) -> int64_t { + return (i * shape[1] * shape[2] * shape[3]) + (j * shape[2] * shape[3]) + + (k * shape[3]) + l; + }; + + // Transpose the filter value. + // TODO: b/336203735 - Use `DenseElementsTransposer` instead of manual + // transpose. + for (int i = 0; i < filter_shape[0]; ++i) { + for (int j = 0; j < filter_shape[1]; ++j) { + for (int k = 0; k < filter_shape[2]; ++k) { + for (int l = 0; l < filter_shape[3]; ++l) { + // [o, 0, 1, i] for `tfl.conv_2d` case`, + // [i, 0, 1, o] for `tfl.depthwise_conv_2d` case. + int old_idx = get_array_idx(filter_shape, i, j, k, l); + int new_idx = is_depthwise + ? get_array_idx(new_filter_shape, k, i, j, l) + : get_array_idx(new_filter_shape, l, i, j, k); + new_filter_constant_values[new_idx] = filter_constant_values[old_idx]; + } + } + } + } + + // Create the new filter constant. + auto new_filter_value_attr_type = + RankedTensorType::getChecked(loc, new_filter_shape, + /*elementType=*/rewriter.getI8Type()); + auto new_filter_constant_value_attr = DenseIntElementsAttr::get( + new_filter_value_attr_type, new_filter_constant_values); + + return new_filter_constant_value_attr; +} + +// Checks if the given convolution op is depthwise. +bool IsDepthwiseConvolution(stablehlo::ConvolutionOp op) { + // `feature_group_count` controls how the input channel dimension is + // split. + // A value bigger than one signals depthwise convolution behavior. + return op.getFeatureGroupCount() > 1; +} + +// Returns kernel output feature dimension of TFLite convolutions. +int64_t GetConvolutionKernelOutputFeatureDimension(bool is_depthwise) { + return is_depthwise ? 3 : 0; +} + +// Returns kernel input feature dimension of TFLite convolutions. +int64_t GetConvolutionKernelInputFeatureDimension(bool is_depthwise) { + return is_depthwise ? 0 : 3; +} + // stablehlo.uniform_quantize -> tfl.quantize // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. class RewriteUniformQuantizeOp @@ -881,24 +1012,6 @@ class RewriteQuantizedConvolutionOp IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); const bool fuse_bias_constant = FindUserOfType(op) && has_i32_output; - stablehlo::ConvDimensionNumbersAttr dimension_numbers = - op.getDimensionNumbers(); - - const int64_t kernel_input_feature_dim = - dimension_numbers.getKernelInputFeatureDimension(); - if (kernel_input_feature_dim != 2) { - LLVM_DEBUG(llvm::dbgs() << "Expected kernel input feature == 2. Got: " - << kernel_input_feature_dim << ".\n"); - return failure(); - } - - const int64_t kernel_output_feature_dim = - dimension_numbers.getKernelOutputFeatureDimension(); - if (kernel_output_feature_dim != 3) { - LLVM_DEBUG(llvm::dbgs() << "Expected kernel output feature == 3. 
Got: " - << kernel_output_feature_dim << ".\n"); - return failure(); - } if (failed(MatchInput(op.getOperand(0)))) { LLVM_DEBUG(llvm::dbgs() @@ -918,6 +1031,12 @@ class RewriteQuantizedConvolutionOp return failure(); } + if (failed(MatchConvolutionFormat(op))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match dimension format for convolution_op.\n"); + return failure(); + } + if (fuse_bias_constant) { Operation* add_op = FindUserOfType(op); if (add_op == nullptr) { @@ -941,7 +1060,7 @@ class RewriteQuantizedConvolutionOp stablehlo::ConvDimensionNumbersAttr dimension_numbers = op.getDimensionNumbers(); - const bool is_depthwise = IsDepthwiseConvolution(op, dimension_numbers); + const bool is_depthwise = IsDepthwiseConvolution(op); const bool is_transpose_conv = IsTransposeConv(op, dimension_numbers); const bool fuse_bias_constant = FindUserOfType(op) && has_i32_output; @@ -1029,13 +1148,6 @@ class RewriteQuantizedConvolutionOp private: static LogicalResult MatchInput(Value input) { auto input_type = input.getType().cast(); - if (input_type.getRank() != 4) { - LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " - "Expected input rank of 4. Got: " - << input_type.getRank() << ".\n"); - return failure(); - } - if (const auto input_element_type = input_type.getElementType(); !IsI8F32UniformQuantizedType(input_element_type)) { LLVM_DEBUG(llvm::dbgs() @@ -1049,13 +1161,6 @@ class RewriteQuantizedConvolutionOp static LogicalResult MatchFilter(Value filter) { auto filter_type = filter.getType().cast(); - if (filter_type.getRank() != 4) { - LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " - "Expected filter rank of 4. Got: " - << filter_type.getRank() << ".\n"); - return failure(); - } - const Type filter_element_type = filter_type.getElementType(); if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { LLVM_DEBUG( @@ -1071,12 +1176,6 @@ class RewriteQuantizedConvolutionOp << filter_element_type << "\n"); return failure(); } - - if (Operation* filter_op = filter.getDefiningOp(); - filter_op == nullptr || !isa(filter_op)) { - LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); - return failure(); - } return success(); } @@ -1155,76 +1254,6 @@ class RewriteQuantizedConvolutionOp return padded_shape; } - // Transposes the filter tensor to match the filter tensor format for - // TFLite convolution. The following transformations are supported: - // - // Depthwise case (`feature_group_count` > 1) - // * Permutates given filter to `[i, 0, 1, o]` format. - // General convolution (`feature_group_count` = 1) - // * Permutates given filter to `[o, 0, 1, i]` format. - // Using TransposeOp doesn't work because the quantized dimension - // changes which violates the constraint for the TransposeOp that the - // input's and output's element type should be the same. 
- DenseIntElementsAttr TransposeFilterValue( - Location loc, PatternRewriter& rewriter, - const DenseIntElementsAttr& filter_value_attr, - const bool is_depthwise) const { - ArrayRef filter_shape = - filter_value_attr.getShapedType().getShape(); - SmallVector filter_constant_values; - for (auto filter_val : filter_value_attr.getValues()) { - filter_constant_values.push_back(filter_val); - } - - SmallVector new_filter_constant_values( - filter_constant_values.size(), 0); - - SmallVector new_filter_shape; - SmallVector transpose_dims; - if (is_depthwise) { - transpose_dims = {2, 0, 1, 3}; - } else { - transpose_dims = {3, 0, 1, 2}; - } - for (int i = 0; i < filter_shape.size(); ++i) { - new_filter_shape.push_back(filter_shape[transpose_dims[i]]); - } - - auto get_array_idx = [](ArrayRef shape, const int i, const int j, - const int k, const int l) -> int64_t { - return (i * shape[1] * shape[2] * shape[3]) + (j * shape[2] * shape[3]) + - (k * shape[3]) + l; - }; - - // Transpose the filter value. - for (int i = 0; i < filter_shape[0]; ++i) { - for (int j = 0; j < filter_shape[1]; ++j) { - for (int k = 0; k < filter_shape[2]; ++k) { - for (int l = 0; l < filter_shape[3]; ++l) { - // [o, 0, 1, i] for `tfl.conv_2d` case`, - // [i, 0, 1, o] for `tfl.depthwise_conv_2d` case. - int old_idx = get_array_idx(filter_shape, i, j, k, l); - int new_idx = is_depthwise - ? get_array_idx(new_filter_shape, k, i, j, l) - : get_array_idx(new_filter_shape, l, i, j, k); - - new_filter_constant_values[new_idx] = - filter_constant_values[old_idx]; - } - } - } - } - - // Create the new filter constant. - auto new_filter_value_attr_type = - RankedTensorType::getChecked(loc, new_filter_shape, - /*elementType=*/rewriter.getI8Type()); - auto new_filter_constant_value_attr = DenseIntElementsAttr::get( - new_filter_value_attr_type, new_filter_constant_values); - - return new_filter_constant_value_attr; - } - std::pair GetDimSize( const ArrayRef shape, const ArrayRef indexes) const { return {shape[indexes[0]], shape[indexes[1]]}; @@ -1335,12 +1364,13 @@ class RewriteQuantizedConvolutionOp // Returns the stride amount for the height and width, respectively. std::pair GetStrides(stablehlo::ConvolutionOp op) const { - DenseI64ArrayAttr window_strides_attr = op.getWindowStridesAttr(); - if (!window_strides_attr) { + std::optional> window_strides_attr = + op.getWindowStrides(); + if (!window_strides_attr.has_value()) { return {1, 1}; // Default values. } - auto window_strides_attr_value = window_strides_attr.asArrayRef(); + auto window_strides_attr_value = window_strides_attr.value(); // It is guaranteed from the spec that it has two values: // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. return {window_strides_attr_value[0], window_strides_attr_value[1]}; @@ -1349,12 +1379,12 @@ class RewriteQuantizedConvolutionOp // Returns the dilation amount for the height and width, respectively. std::pair GetDilationFactors( stablehlo::ConvolutionOp op) const { - DenseI64ArrayAttr lhs_dilation_attr = op.getLhsDilationAttr(); - if (!lhs_dilation_attr) { + std::optional> lhs_dilation_attr = op.getLhsDilation(); + if (!lhs_dilation_attr.has_value()) { return {1, 1}; // Default values. } - auto lhs_dilation_attr_value = lhs_dilation_attr.asArrayRef(); + auto lhs_dilation_attr_value = lhs_dilation_attr.value(); // It is guaranteed from the spec that it has two values: // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. 
return {lhs_dilation_attr_value[0], lhs_dilation_attr_value[1]}; @@ -1371,8 +1401,10 @@ class RewriteQuantizedConvolutionOp auto filter_constant_value_attr = cast( cast(filter_value.getDefiningOp()).getValue()); const DenseIntElementsAttr new_filter_value_attr = - TransposeFilterValue(filter_op->getLoc(), rewriter, - filter_constant_value_attr, is_depthwise); + TransposeFilterInConvolution(filter_op->getLoc(), rewriter, + filter_constant_value_attr, is_depthwise); + int64_t kernel_output_feature_dim = + GetConvolutionKernelOutputFeatureDimension(is_depthwise); // Create a new quantized tensor type for the filter. This is required // because the quantized dimension is changed from 3 -> 0. `TFL::Conv2DOp` // requires the quantized dimension to be 0 because it accepts a filter @@ -1383,14 +1415,15 @@ class RewriteQuantizedConvolutionOp auto new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType( filter_op->getLoc(), *op.getContext(), filter_uniform_quantized_type.getScales(), - filter_uniform_quantized_type.getZeroPoints(), is_depthwise ? 3 : 0, + filter_uniform_quantized_type.getZeroPoints(), + /*quantization_dimension=*/kernel_output_feature_dim, /*narrow_range=*/true); const auto new_filter_result_type = RankedTensorType::getChecked( filter_op->getLoc(), /*shape=*/new_filter_value_attr.getShapedType().getShape(), /*type=*/new_filter_quantized_type); const int64_t num_output_features = - new_filter_result_type.getShape()[is_depthwise ? 3 : 0]; + new_filter_result_type.getShape()[kernel_output_feature_dim]; new_filter_constant_op = rewriter.create( filter_op->getLoc(), /*output=*/TypeAttr::get(new_filter_result_type), new_filter_value_attr); @@ -1441,15 +1474,6 @@ class RewriteQuantizedConvolutionOp } return bias; } - - bool IsDepthwiseConvolution( - stablehlo::ConvolutionOp op, - const stablehlo::ConvDimensionNumbersAttr dimension_numbers) const { - // `feature_group_count` controls how the input channel dimension is - // split. - // A value bigger than one signals depthwise convolution behavior. - return op.getFeatureGroupCount() > 1; - } }; // Rewrites quantized `stablehlo.transpose` to `tfl.transpose`. @@ -2124,28 +2148,27 @@ class RewriteQuantizedConstantOp } }; -// Splits dot-like hybrid quantized StableHLO ops into `tfl.dequantize` and -// float StableHLO op. Legalization of float StableHLO op depends on existing -// passes for conversion of StableHLO -> MHLO -> TF -> TFL. -template -class RewriteHybridQuantizedDotLikeOp : public OpRewritePattern { +// Splits hybrid quantized `stablehlo.dot_general` into `tfl.dequantize` and +// float `stablehlo.dot_general` op. Legalization of float +// `stablehlo.dot_general` op relies on existing passes for conversion of +// StableHLO -> MHLO -> TF -> TFL. +class RewriteHybridQuantizedDotGeneralOp + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(OpType op) const override { - if (op->getNumOperands() != 2 || op->getNumResults() != 1) { - return failure(); - } + LogicalResult match(stablehlo::DotGeneralOp op) const override { // Lhs and result should not be quantized and rhs should be quantized. 
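// Editorial sketch (not part of the patch): "hybrid" means float activations
// with quantized weights. With abbreviated types, a match such as
//   %r = stablehlo.dot_general %lhs, %w
//        : (tensor<1x4xf32>, tensor<4x8x!quant.uniform<i8:f32, 0.1>>) -> tensor<1x8xf32>
// is rewritten to
//   %wf = "tfl.dequantize"(%w)
//        : (tensor<4x8x!quant.uniform<i8:f32, 0.1>>) -> tensor<4x8xf32>
//   %r = stablehlo.dot_general %lhs, %wf
//        : (tensor<1x4xf32>, tensor<4x8xf32>) -> tensor<1x8xf32>
// leaving a float dot_general for the existing StableHLO -> MHLO -> TF -> TFL
// legalization path.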
return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && IsQuantizedTensorType(op->getOperand(1).getType()) && !IsQuantizedTensorType(op->getResult(0).getType())); } - void rewrite(OpType op, PatternRewriter& rewriter) const override { - Value rhs = op.getOperand(1); + void rewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { + Value rhs = op.getRhs(); Type lhs_element_type = - op.getOperand(0).getType().template cast().getElementType(); + op.getLhs().getType().template cast().getElementType(); Type dequantized_rhs_type = quant::CloneTypeWithNewElementType(rhs.getType(), lhs_element_type); auto dq = rewriter.create( @@ -2155,17 +2178,135 @@ class RewriteHybridQuantizedDotLikeOp : public OpRewritePattern { } }; +// Splits hybrid quantized `stablehlo.convolution` into `tfl.dequantize` and +// float `stablehlo.convolution` op. Weight tensor is transposed to match the +// filter tensor format for TFLite convolution. +// Legalization of float `stablehlo.convolution` op relies on existing passes +// for conversion of StableHLO -> MHLO -> TF -> TFL. +class RewriteHybridQuantizedConvolutionOp + : public OpRewritePattern { + public: + explicit RewriteHybridQuantizedConvolutionOp(MLIRContext* ctx) + : OpRewritePattern(ctx, /*benefit=*/5) {} + + LogicalResult match(stablehlo::ConvolutionOp op) const override { + if (failed(MatchConvolutionFormat(op))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match dimension format for convolution_op.\n"); + return failure(); + } + // Lhs and result should not be quantized and rhs should be quantized. + return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && + IsQuantizedTensorType(op->getOperand(1).getType()) && + !IsQuantizedTensorType(op->getResult(0).getType())); + } + + void rewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { + const bool is_depthwise = IsDepthwiseConvolution(op); + + Operation* filter_op = op.getRhs().getDefiningOp(); + auto filter_constant_value_attr = cast( + cast(filter_op).getValue()); + const DenseIntElementsAttr new_filter_value_attr = + TransposeFilterInConvolution(filter_op->getLoc(), rewriter, + filter_constant_value_attr, is_depthwise); + + Type new_filter_type = GetNewWeightQuantizedType( + /*context=*/op.getContext(), /*location=*/filter_op->getLoc(), + /*new_shape=*/new_filter_value_attr.getShapedType().getShape(), + /*filter_type=*/op.getRhs().getType(), is_depthwise); + auto new_filter = rewriter.create( + filter_op->getLoc(), + /*output=*/TypeAttr::get(new_filter_type), new_filter_value_attr); + stablehlo::ConvDimensionNumbersAttr new_dimension_numbers = + GetTflDimensionNumbers(rewriter.getContext(), op.getDimensionNumbers(), + is_depthwise); + op.setDimensionNumbersAttr(new_dimension_numbers); + + Type lhs_element_type = + op.getOperand(0).getType().template cast().getElementType(); + Type dequantized_rhs_type = quant::CloneTypeWithNewElementType( + new_filter.getType(), lhs_element_type); + auto dq = rewriter.create( + op->getLoc(), /*output=*/dequantized_rhs_type, + /*input=*/new_filter); + rewriter.replaceAllUsesExcept(filter_op->getResult(0), dq.getOutput(), dq); + } + + private: + // Returns new quantized type for weights after transpose. 
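// Editorial note (not part of the patch): transposing the filter moves the
// per-axis quantized dimension with it. For the general case the [0, 1, i, o]
// filter becomes [o, 0, 1, i], so quantization_dimension changes from 3 to 0
// (tfl.conv_2d expects per-axis quantization on dim 0); in the depthwise case
// the output-feature axis stays at dim 3. Per-tensor quantized types are
// carried over unchanged.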
+  Type GetNewWeightQuantizedType(MLIRContext* context, Location location,
+                                 ArrayRef<int64_t> new_shape, Type filter_type,
+                                 bool is_depthwise) const {
+    auto tensor_type = filter_type.cast<TensorType>();
+    auto element_type = tensor_type.getElementType();
+    RankedTensorType new_filter_result_type;
+    if (element_type.isa<UniformQuantizedPerAxisType>()) {
+      auto per_axis_type = element_type.cast<UniformQuantizedPerAxisType>();
+      int64_t kernel_output_feature_dim =
+          GetConvolutionKernelOutputFeatureDimension(is_depthwise);
+      auto new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType(
+          location, *context, per_axis_type.getScales(),
+          per_axis_type.getZeroPoints(),
+          /*quantization_dimension=*/kernel_output_feature_dim,
+          /*narrow_range=*/true);
+      new_filter_result_type =
+          RankedTensorType::getChecked(location,
+                                       /*shape=*/new_shape,
+                                       /*type=*/new_filter_quantized_type);
+    } else if (element_type.isa<UniformQuantizedType>()) {
+      auto per_tensor_type = element_type.cast<UniformQuantizedType>();
+      new_filter_result_type =
+          RankedTensorType::getChecked(location,
+                                       /*shape=*/new_shape,
+                                       /*type=*/per_tensor_type);
+    } else {
+      LLVM_DEBUG(
+          llvm::dbgs()
+          << "Weight tensor elements do not have uniform quantized type.\n");
+    }
+    return new_filter_result_type;
+  }
+
+  // Returns the dimension numbers of the given stablehlo convolution
+  // attribute, with the filter dimensions transposed to match the TFLite
+  // format.
+  // Depthwise case (`feature_group_count` > 1)
+  //   * `[0, 1, i, o]` -> `[i, 0, 1, o]` format.
+  // General convolution (`feature_group_count` = 1)
+  //   * `[0, 1, i, o]` -> `[o, 0, 1, i]` format.
+  stablehlo::ConvDimensionNumbersAttr GetTflDimensionNumbers(
+      MLIRContext* context,
+      stablehlo::ConvDimensionNumbersAttr dimension_numbers,
+      bool is_depthwise) const {
+    int64_t kernel_input_feature_dim =
+        GetConvolutionKernelInputFeatureDimension(is_depthwise);
+    int64_t kernel_output_feature_dim =
+        GetConvolutionKernelOutputFeatureDimension(is_depthwise);
+    SmallVector<int64_t> kernel_spatial_dims{1, 2};
+
+    return stablehlo::ConvDimensionNumbersAttr::get(
+        context, dimension_numbers.getInputBatchDimension(),
+        dimension_numbers.getInputFeatureDimension(),
+        dimension_numbers.getInputSpatialDimensions(),
+        kernel_input_feature_dim, kernel_output_feature_dim,
+        kernel_spatial_dims, dimension_numbers.getOutputBatchDimension(),
+        dimension_numbers.getOutputFeatureDimension(),
+        dimension_numbers.getOutputSpatialDimensions());
+  }
+};
+
 void UniformQuantizedStableHloToTflPass::runOnOperation() {
   func::FuncOp func_op = getOperation();
   MLIRContext& ctx = getContext();
   RewritePatternSet patterns(&ctx);
-  patterns.add<RewriteHybridQuantizedDotLikeOp<stablehlo::ConvolutionOp>,
-               RewriteHybridQuantizedDotLikeOp<stablehlo::DotGeneralOp>,
-               RewriteUniformDequantizeOp, RewriteUniformQuantizeOp,
-               RewriteQuantizedAddOp, RewriteQuantizedBroadcastInDimOp,
-               RewriteQuantizedConcatenateOp, RewriteQuantizedConstantOp,
-               RewriteQuantizedConvolutionOp,
+  patterns.add
, proj_clip = 0.01 : f32 } : (tensor<1x528xf32>, tensor<2048x528xf32>, tensor<2048x528xf32>, tensor<2048x528xf32>, tensor<2048x528xf32>, tensor<2048x640xf32>, tensor<2048x640xf32>, tensor<2048x640xf32>, tensor<2048x640xf32>, none, none, none, tensor<2048xf32>, tensor<2048xf32>, tensor<2048xf32>, tensor<2048xf32>, tensor<640x2048xf32>, tensor<640xf32>, tensor<1x640xf32>, tensor<1x2048xf32>, tensor<2048xf32>, tensor<2048xf32>, tensor<2048xf32>, tensor<2048xf32>) -> tensor<1x640xf32>
   func.return %0 : tensor<1x640xf32>

-// CHECK: %[[NONE:.+]] = "tfl.no_value"() {value} : () -> none
+// CHECK: %[[NONE:.+]] = "tfl.no_value"() <{value}> : () -> none
 // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[NONE]],
%[[NONE]], %[[NONE]], %arg9, %arg10, %arg11, %arg12, %arg13, %[[NONE]], %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) } @@ -282,11 +282,11 @@ func.func @keepCustomFlexOps(%arg0: tensor<1x10xf32>) -> tensor<1x10xf32> { %2 = "tfl.custom"(%1, %arg0) {custom_code = "FlexAddV2", custom_option = #tfl} : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> "tfl.custom"(%0, %2) {custom_code = "FlexAssignVariableOp", custom_option = #tfl} : (tensor>>, tensor<1x10xf32>) -> () %3 = "tfl.custom"(%0) {custom_code = "FlexReadVariableOp", custom_option = #tfl} : (tensor>>) -> tensor<1x10xf32> - // CHECK: %0 = "tfl.custom"() {custom_code = "FlexVarHandleOp" - // CHECK-NEXT: %1 = "tfl.custom"(%0) {custom_code = "FlexReadVariableOp" - // CHECK-NEXT: %2 = "tfl.custom"(%1, %arg0) {custom_code = "FlexAddV2" - // CHECK-NEXT: "tfl.custom"(%0, %2) {custom_code = "FlexAssignVariableOp" - // CHECK-NEXT: %3 = "tfl.custom"(%0) {custom_code = "FlexReadVariableOp" + // CHECK: %0 = "tfl.custom"() <{custom_code = "FlexVarHandleOp" + // CHECK-NEXT: %1 = "tfl.custom"(%0) <{custom_code = "FlexReadVariableOp" + // CHECK-NEXT: %2 = "tfl.custom"(%1, %arg0) <{custom_code = "FlexAddV2" + // CHECK-NEXT: "tfl.custom"(%0, %2) <{custom_code = "FlexAssignVariableOp" + // CHECK-NEXT: %3 = "tfl.custom"(%0) <{custom_code = "FlexReadVariableOp" func.return %3 : tensor<1x10xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index f244d15294c253..9626a292b8eb6d 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -416,7 +416,7 @@ func.func @reshape_dynamic_output() -> tensor { %input = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %shape = arith.constant dense<[4]> : tensor<1xi32> - // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 2, 3, 4]> : tensor<4xi32>}> : () -> tensor // CHECK: return %[[CST]] %0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor func.return %0 : tensor @@ -438,7 +438,7 @@ func.func @range_int() -> tensor { %cst_1 = arith.constant dense<4> : tensor %cst_2 = arith.constant dense<1> : tensor - // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor @@ -450,7 +450,7 @@ func.func @range_float() -> tensor { %cst_1 = arith.constant dense<4.0> : tensor %cst_2 = arith.constant dense<1.0> : tensor - // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() <{value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>}> : () -> tensor // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor @@ -463,7 +463,7 @@ func.func @range_float_neg_delta() -> tensor { %cst_1 = arith.constant dense<-4.0> : tensor %cst_2 = arith.constant dense<-1.0> : tensor - // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor + // 
CHECK: %[[CST:.*]] = "tfl.pseudo_const"() <{value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>}> : () -> tensor // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor @@ -475,7 +475,7 @@ func.func @range_float_nonzero_base() -> tensor { %cst_1 = arith.constant dense<7.0> : tensor %cst_2 = arith.constant dense<1.5> : tensor - // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() <{value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>}> : () -> tensor // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor @@ -508,7 +508,7 @@ func.func @transpose_dynamic() -> tensor { %cst = arith.constant dense<[1, 2, 3]> : tensor<3xi32> %cst_perm = arith.constant dense<0> : tensor<1xi32> - // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() <{value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>}> : () -> tensor // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor func.return %0 : tensor @@ -567,7 +567,7 @@ func.func @ConstantFoldBinaryOpDynamicOutput() -> tensor { %87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor func.return %87 : tensor - // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() <{value = dense<[-5, 0]> : tensor<2xi32>}> : () -> tensor // CHECK: return %[[CST]] } @@ -580,7 +580,7 @@ func.func @add_dense_dense_int_same_shape_dynamic() -> tensor { func.return %2 : tensor - // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() <{value = dense<[5, 22, -2, 98]> : tensor<4xi32>}> : () -> tensor // CHECK: return %[[CST]] } @@ -603,7 +603,7 @@ func.func @concat_3_tensors_1_empty() -> tensor { %3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor func.return %3 : tensor - // CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"} + // CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> // CHECK: return %0 : tensor } @@ -835,7 +835,7 @@ func.func @NoFoldFullyConnectedNonFloat() -> tensor<1024xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<512xf32> // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<2> : tensor<1024x512xi8> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<4.000000e+00> : tensor<1024xf32> - // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xi8>, tensor<1024xf32>) -> tensor<1024xf32> + // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<512xf32>, tensor<1024x512xi8>, tensor<1024xf32>) -> tensor<1024xf32> // CHECK: return %[[VAL]] : tensor<1024xf32> } 
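Every CHECK-line update in these test diffs tracks the same MLIR printer change: inherent op attributes now print as a properties dictionary, `<{...}>`, placed between the operand list and the type signature, where the old generic syntax folded them into the plain attribute dictionary `{...}`. A minimal before/after sketch (the op and types are illustrative, not drawn from this diff):

 // Old generic printer: inherent attributes in the attribute dictionary.
 %0 = "tfl.add"(%a, %b) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
 // New generic printer: the same inherent attribute as a property.
 %0 = "tfl.add"(%a, %b) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>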
@@ -851,7 +851,7 @@ func.func @NoFoldFullyConnectedHighRank() -> tensor<2x1024xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2x512xf32> // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<2.000000e+00> : tensor<1024x512xf32> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<4.000000e+00> : tensor<1024xf32> - // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> + // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> // CHECK: return %[[VAL]] : tensor<2x1024xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/decompose-hybrid-quantization.mlir b/tensorflow/compiler/mlir/lite/tests/decompose-hybrid-quantization.mlir index ce2c896ccde2ab..b656f8c649b5c4 100644 --- a/tensorflow/compiler/mlir/lite/tests/decompose-hybrid-quantization.mlir +++ b/tensorflow/compiler/mlir/lite/tests/decompose-hybrid-quantization.mlir @@ -2,9 +2,9 @@ // CHECK-LABEL: @test_conv2d_float func.func @test_conv2d_float(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x16xf32> { - // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_const"() {value = dense<42> : tensor<16x1x1x8xi8>} - // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_const"() {value = dense<1> : tensor<16x1x1x8xi8>} - // CHECK-DAG: %[[VAL2:.+]] = "tfl.conv_2d"(%arg0, %[[VAL0]], %[[VAL1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} + // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_const"() <{value = dense<42> : tensor<16x1x1x8xi8>}> + // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_const"() <{value = dense<1> : tensor<16x1x1x8xi8>}> + // CHECK-DAG: %[[VAL2:.+]] = "tfl.conv_2d"(%arg0, %[[VAL0]], %[[VAL1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> // CHECK: return %[[VAL2]] %0 = "tfl.pseudo_const"() {value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8xf32> %1 = "tfl.pseudo_const"() {value = dense<1> : tensor<16x1x1x8xi8>} : () -> tensor<16xf32> @@ -16,9 +16,9 @@ func.func @test_conv2d_float(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x16x // CHECK-LABEL: @test_conv2d_qi8 func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { - // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<{{.+}}>, value = dense<42> : tensor<16x1x1x8xi8>} - // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<{{.+}}>, value = dense<0> : tensor<16xi32>} - // CHECK-DAG: %[[VAL2:.+]] = "tfl.conv_2d"(%arg0, %0, %1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} + // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<{{.+}}>, value = dense<42> : tensor<16x1x1x8xi8>}> + // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<{{.+}}>, value = dense<0> : tensor<16xi32>}> + // CHECK-DAG: %[[VAL2:.+]] = "tfl.conv_2d"(%arg0, %0, %1) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, 
stride_w = 1 : i32}> // CHECK: return %[[VAL2]] %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x!quant.uniform>, value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform> %1 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<0> : tensor<16xi32>} : () -> tensor<16x!quant.uniform> @@ -31,8 +31,8 @@ func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) // CHECK-LABEL: @test_conv2d_qi16 func.func @test_conv2d_qi16(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { // CHECK-DAG: %[[BIAS:.+]] = arith.constant dense<0> : tensor<16xi64> - // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<{{.+}}>, value = dense<42> : tensor<16x1x1x8xi8>} - // CHECK-DAG: %[[VAL1:.+]] = "tfl.conv_2d"(%arg0, %[[VAL0]], %[[BIAS]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} + // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<{{.+}}>, value = dense<42> : tensor<16x1x1x8xi8>}> + // CHECK-DAG: %[[VAL1:.+]] = "tfl.conv_2d"(%arg0, %[[VAL0]], %[[BIAS]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> // CHECK: return %[[VAL1]] %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x!quant.uniform>, value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform> %1 = "arith.constant"() {value = dense<0> : tensor<16xi64>} : () -> tensor<16xi64> @@ -44,12 +44,12 @@ func.func @test_conv2d_qi16(%arg0: tensor<1x32x32x8x!quant.uniform // CHECK-LABEL: @test_conv2d_replace_qi8 func.func @test_conv2d_replace_qi8(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x16x!quant.uniform> { - // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<{{.+}}>, value = dense<42> : tensor<16x1x1x8xi8>} - // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<{{.+}}>, value = dense<0> : tensor<16xi32>} + // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<{{.+}}>, value = dense<42> : tensor<16x1x1x8xi8>}> + // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<{{.+}}>, value = dense<0> : tensor<16xi32>}> // CHECK-DAG: %[[VAL2:.+]] = "tfl.dequantize"(%[[VAL0]]) // CHECK-DAG: %[[VAL3:.+]] = "tfl.dequantize"(%[[VAL1]]) - // CHECK-DAG: %[[VAL4:.+]] = "tfl.conv_2d"(%arg0, %[[VAL2]], %[[VAL3]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} - // CHECK-DAG: %[[VAL5:.+]] = "tfl.quantize"(%4) {qtype = tensor<1x32x32x16x!quant.uniform>} + // CHECK-DAG: %[[VAL4:.+]] = "tfl.conv_2d"(%arg0, %[[VAL2]], %[[VAL3]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> + // CHECK-DAG: %[[VAL5:.+]] = "tfl.quantize"(%4) <{qtype = tensor<1x32x32x16x!quant.uniform>}> // CHECK: return %[[VAL5]] %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x!quant.uniform>, value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform> %1 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<0> : tensor<16xi32>} : () -> tensor<16x!quant.uniform> @@ -61,11 +61,11 @@ func.func @test_conv2d_replace_qi8(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x // CHECK-LABEL: @test_conv2d_replace_float func.func @test_conv2d_replace_float(%arg0: 
tensor<1x32x32x8xf32>) -> tensor<1x32x32x16xf32> { - // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x{{.+}}>, value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform<{{.+}}>> - // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<16x{{.+}}>, value = dense<0> : tensor<16xi32>} : () -> tensor<16x!quant.uniform<{{.+}}>> + // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<16x1x1x8x{{.+}}>, value = dense<42> : tensor<16x1x1x8xi8>}> : () -> tensor<16x1x1x8x!quant.uniform<{{.+}}>> + // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<16x{{.+}}>, value = dense<0> : tensor<16xi32>}> : () -> tensor<16x!quant.uniform<{{.+}}>> // CHECK-DAG: %[[VAL2:.+]] = "tfl.dequantize"(%[[VAL0]]) // CHECK-DAG: %[[VAL3:.+]] = "tfl.dequantize"(%[[VAL1]]) - // CHECK-DAG: %[[VAL4:.+]] = "tfl.conv_2d"(%arg0, %[[VAL2]], %[[VAL3]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} + // CHECK-DAG: %[[VAL4:.+]] = "tfl.conv_2d"(%arg0, %[[VAL2]], %[[VAL3]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> // CHECK: return %[[VAL4]] %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x!quant.uniform>, value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform> %1 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<0> : tensor<16xi32>} : () -> tensor<16x!quant.uniform> @@ -77,10 +77,10 @@ func.func @test_conv2d_replace_float(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x3 // CHECK-LABEL: @test_conv3d_float func.func @test_conv3d_float(%arg0: tensor<1x32x32x32x8xf32>) -> tensor<1x32x32x32x16xf32> { - // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16xf32> - // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x1x8x16x!quant.uniform<{{.+}}>>, value = dense<42> : tensor<1x1x1x8x16xi8>} + // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<16xf32> + // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x1x8x16x!quant.uniform<{{.+}}>>, value = dense<42> : tensor<1x1x1x8x16xi8>}> // CHECK: %[[VAL2:.+]] = "tfl.dequantize"(%[[VAL1]]) : (tensor<1x1x1x8x16x!quant.uniform<{{.+}}>>) -> tensor<1x1x1x8x16xf32> - // CHECK: %[[VAL3:.+]] = "tfl.conv_3d"(%arg0, %[[VAL2]], %[[VAL0]]) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} + // CHECK: %[[VAL3:.+]] = "tfl.conv_3d"(%arg0, %[[VAL2]], %[[VAL0]]) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> // CHECK: return %[[VAL3]] %0 = "tfl.pseudo_qconst"() {qtype = tensor<1x1x8x16x!quant.uniform>, value = dense<42> : tensor<1x1x1x8x16xi8>} : () -> tensor<1x1x1x8x16x!quant.uniform> %1 = "tfl.pseudo_const"() { value = dense<1.0> : tensor<16xf32>} : () -> tensor<16xf32> @@ -92,12 +92,12 @@ func.func @test_conv3d_float(%arg0: tensor<1x32x32x32x8xf32>) -> tensor<1x32x32x // CHECK-LABEL: @test_transpose_conv2d func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x16xf32> { - // CHECK-DAG: %[[SHAPE:.+]] = "tfl.pseudo_const"() {value = 
dense<[1, 32, 32, 16]> - // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<16x{{.+}}>, value = dense<1> : tensor<16xi32>} - // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<16x{{.+}}>, value = dense<2> : tensor<16xi32>} + // CHECK-DAG: %[[SHAPE:.+]] = "tfl.pseudo_const"() <{value = dense<[1, 32, 32, 16]> + // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<16x{{.+}}>, value = dense<1> : tensor<16xi32>}> + // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<16x{{.+}}>, value = dense<2> : tensor<16xi32>}> // CHECK-DAG: %[[VAL2:.+]] = "tfl.dequantize"(%[[VAL0]]) // CHECK-DAG: %[[VAL3:.+]] = "tfl.dequantize"(%[[VAL1]]) - // CHECK-DAG: %[[VAL4:.+]] = "tfl.transpose_conv"(%[[SHAPE]], %[[VAL2]], %arg0, %[[VAL3]]) {fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} + // CHECK-DAG: %[[VAL4:.+]] = "tfl.transpose_conv"(%[[SHAPE]], %[[VAL2]], %arg0, %[[VAL3]]) <{fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> // CHECK: return %[[VAL4]] %0 = "tfl.pseudo_const"() { value = dense<[1, 32, 32, 16]> : tensor<4xi32> } : () -> tensor<4xi32> %1 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<1> : tensor<16xi32>} : () -> tensor<16x1x1x8x!quant.uniform> @@ -110,11 +110,11 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32 // CHECK-LABEL: @test_depthwise_conv2d_replace_float func.func @test_depthwise_conv2d_replace_float(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { - // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<{{.+}}>>, value = dense<42> : tensor<32x3x3x3xi8>} - // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<{{.+}}>>, value = dense<0> : tensor<32xi32>} + // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform<{{.+}}>>, value = dense<42> : tensor<32x3x3x3xi8>}> + // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform<{{.+}}>>, value = dense<0> : tensor<32xi32>}> // CHECK-DAG: %[[VAL2:.+]] = "tfl.dequantize"(%[[VAL0]]) : (tensor<32x3x3x3x!quant.uniform<{{.+}}>>) // CHECK-DAG: %[[VAL3:.+]] = "tfl.dequantize"(%[[VAL1]]) : (tensor<32x!quant.uniform<{{.+}}) - // CHECK-DAG: %[[VAL4:.+]] = "tfl.depthwise_conv_2d"(%arg0, %[[VAL2]], %[[VAL3]]) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} + // CHECK-DAG: %[[VAL4:.+]] = "tfl.depthwise_conv_2d"(%arg0, %[[VAL2]], %[[VAL3]]) <{depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}> // CHECK: return %[[VAL4]] %0 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform>, value = dense<42> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform> %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> @@ -126,11 +126,11 @@ func.func @test_depthwise_conv2d_replace_float(%arg0: tensor<1x224x224x3xf32>) - // CHECK-LABEL: @test_fullyconnected_replace_float func.func @test_fullyconnected_replace_float(%arg0: tensor<4x256x6x6xf32>) -> tensor<4x256x36xf32> { - // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<36x36x!quant.uniform<{{.+}}>>, value = 
dense<42> : tensor<36x36xi8>} - // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<36x!quant.uniform<{{.+}}>>, value = dense<0> : tensor<36xi32>} + // CHECK-DAG: %[[VAL0:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<36x36x!quant.uniform<{{.+}}>>, value = dense<42> : tensor<36x36xi8>}> + // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() <{qtype = tensor<36x!quant.uniform<{{.+}}>>, value = dense<0> : tensor<36xi32>}> // CHECK-DAG: %[[VAL2:.+]] = "tfl.dequantize"(%[[VAL0]]) : (tensor<36x36x!quant.uniform>) // CHECK-DAG: %[[VAL3:.+]] = "tfl.dequantize"(%[[VAL1]]) : (tensor<36x!quant.uniform>) - // CHECK: %[[VAL4:.+]] = "tfl.fully_connected"(%arg0, %[[VAL2]], %[[VAL3]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[VAL4:.+]] = "tfl.fully_connected"(%arg0, %[[VAL2]], %[[VAL3]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: return %[[VAL4]] %0 = "tfl.pseudo_qconst"() {qtype = tensor<36x36x!quant.uniform>, value = dense<42> : tensor<36x36xi8>} : () -> tensor<36x36x!quant.uniform> %1 = "tfl.pseudo_qconst"() {qtype = tensor<36x!quant.uniform>, value = dense<0> : tensor<36xi32>} : () -> tensor<36x!quant.uniform> diff --git a/tensorflow/compiler/mlir/lite/tests/default_quant_params.mlir b/tensorflow/compiler/mlir/lite/tests/default_quant_params.mlir index b062de1e84aede..2bb3020618d5c7 100644 --- a/tensorflow/compiler/mlir/lite/tests/default_quant_params.mlir +++ b/tensorflow/compiler/mlir/lite/tests/default_quant_params.mlir @@ -5,10 +5,10 @@ func.func @hardcode_all(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tenso %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32> func.return %0 : tensor<2x2xf32> -// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform>} -// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform>} +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<2x1x!quant.uniform>}> +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<2x2x!quant.uniform>}> // Quantized tfl.add -// CHECK: %[[add:.*]] = tfl.add(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform> +// CHECK: %[[add:.*]] = tfl.add(%[[q1]], %[[q0]]) <{fused_activation_function = "NONE"}> : (tensor<2x2x!quant.uniform> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform>) // CHECK: return %[[dq]] } @@ -20,9 +20,9 @@ func.func @hardcode_input(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> ten %4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32> func.return %4 : tensor<2x2xf32> -// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform>} -// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform>} -// CHECK: %[[add:.*]] = tfl.add(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform> +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<2x1x!quant.uniform>}> +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<2x2x!quant.uniform>}> +// CHECK: %[[add:.*]] = tfl.add(%[[q1]], %[[q0]]) <{fused_activation_function = "NONE"}> : (tensor<2x2x!quant.uniform> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform>) // CHECK: return %[[dq]] } @@ -33,8 +33,8 @@ func.func @hardcode_input_deq(%arg0: tensor<2x2x!quant.uniform>, %a %4 = "tfl.add"(%1, 
%arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32> func.return %4 : tensor<2x2xf32> -// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform>} -// CHECK: %[[add:.*]] = tfl.add(%arg0, %[[q]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform> +// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<2x1x!quant.uniform>}> +// CHECK: %[[add:.*]] = tfl.add(%arg0, %[[q]]) <{fused_activation_function = "NONE"}> : (tensor<2x2x!quant.uniform> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform>) // CHECK: return %[[dq]] } @@ -48,9 +48,9 @@ func.func @hardcode_output(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> te %4 = "tfl.add"(%2, %3) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32> func.return %4 : tensor<2x2xf32> -// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform>} -// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform>} -// CHECK: %[[add:.*]] = tfl.add(%[[q0]], %[[q1]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform> +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<2x2x!quant.uniform>}> +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<2x1x!quant.uniform>}> +// CHECK: %[[add:.*]] = tfl.add(%[[q0]], %[[q1]]) <{fused_activation_function = "NONE"}> : (tensor<2x2x!quant.uniform> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform>) // CHECK: return %[[dq]] } @@ -81,8 +81,8 @@ func.func @test_conv_2d_activation_and_bias(%arg0: tensor<1x224x224x3xf32>, %arg %1 = "tfl.conv_2d"(%arg0, %0, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> func.return %1 : tensor<1x112x112x32xf32> -// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg2) {qtype = tensor<32x!quant.uniform>} -// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform>} +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg2) <{qtype = tensor<32x!quant.uniform>}> +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x224x224x3x!quant.uniform>}> // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%[[q1]], %arg1, %[[q0]]) // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[conv]]) : (tensor<1x112x112x32x!quant.uniform>) // CHECK: return %[[dq]] diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt index ada1c80dfd7535..3574ae83a41998 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt @@ -450,9 +450,9 @@ node { # MLIR-SAME: inputs = "input" # MLIR-SAME: outputs = "output" # MLIR: %[[shape:.*]] = arith.constant dense<[1, -1, 31]> : tensor<3xi32> -# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform:f32:0, {0.12581039038230116, -# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} +# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<186x!quant.uniform:f32:0, {0.12581039038230116, +# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], 
%[[bias]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} # MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform>, tensor<3xi32>) # MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform> # MLIR: } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel_4bit.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel_4bit.pbtxt index 60027cb443a091..7040bd424d02a8 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel_4bit.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel_4bit.pbtxt @@ -450,9 +450,9 @@ node { # MLIR-SAME: inputs = "input" # MLIR-SAME: outputs = "output" # MLIR: %[[shape:.*]] = arith.constant dense<[1, -1, 31]> : tensor<3xi32> -# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform:f32:0, {2.2825599397931779, -# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} +# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<186x!quant.uniform:f32:0, {2.2825599397931779, +# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} # MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform>, tensor<3xi32>) # MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform> # MLIR: } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_without_identity.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_without_identity.pbtxt index 6bacdbda2f933b..07c167fdfe01fc 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_without_identity.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_without_identity.pbtxt @@ -410,9 +410,9 @@ node { # MLIR-SAME: inputs = "input" # MLIR-SAME: outputs = "output" # MLIR: %[[shape:.*]] = arith.constant dense<[1, -1, 31]> : tensor<3xi32> -# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform:f32:0, {0.12581039038230116, -# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} +# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<186x!quant.uniform:f32:0, {0.12581039038230116, +# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} # MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform>, tensor<3xi32>) # MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform> # MLIR: } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_without_identity_4bit.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_without_identity_4bit.pbtxt index 12e02dbd014d1e..5e7f04d6beaaa4 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_without_identity_4bit.pbtxt +++ 
b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_without_identity_4bit.pbtxt @@ -410,9 +410,9 @@ node { # MLIR-SAME: inputs = "input" # MLIR-SAME: outputs = "output" # MLIR: %[[shape:.*]] = arith.constant dense<[1, -1, 31]> : tensor<3xi32> -# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<186x!quant.uniform:f32:0, {2.2825599397931779, -# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} +# MLIR: %[[bias:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<186x!quant.uniform:f32:0, {2.2825599397931779, +# MLIR: %[[conv:.*]] = "tfl.conv_2d"(%[[ARG_0]], %[[weight]], %[[bias]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} # MLIR: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x1x1x186x!quant.uniform>, tensor<3xi32>) # MLIR: return %[[reshape]] : tensor<1x6x31x!quant.uniform> # MLIR: } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/quant_stats.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/quant_stats.pbtxt index ebab9a55611287..a2f21223929d79 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/quant_stats.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/quant_stats.pbtxt @@ -58,7 +58,7 @@ versions { # MLIR-LABEL: func @main(%arg0: tensor<4x!quant.uniform>, %arg1: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> # MLIR-SAME: attributes {tf.entry_function = {control_outputs = "", inputs = "input0,input1", outputs = "Add"}} { -# MLIR-NEXT: %[[add:.*]] = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<4x!quant.uniform>, tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> +# MLIR-NEXT: %[[add:.*]] = tfl.add(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<4x!quant.uniform>, tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> # MLIR-NEXT: return %[[add]] : tensor<4x!quant.uniform> # MLIR-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt index adbc83bde4b8fe..293fe283ee2685 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt @@ -79,12 +79,12 @@ versions { # CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} { # CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> -# CHECK-DAG: %[[VAL_3:.*]] = "tfl.no_value"() {value} : () -> none +# CHECK-DAG: %[[VAL_3:.*]] = "tfl.no_value"() <{value}> : () -> none # CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : tensor -# CHECK: %[[VAL_7:.*]]:2 = "tfl.split"(%[[VAL_6]], %[[VAL_0]]) {num_splits = 2 : i32} : (tensor, tensor<2x5x3xf32>) -> (tensor<1x5x3xf32>, tensor<1x5x3xf32>) +# CHECK: %[[VAL_7:.*]]:2 = "tfl.split"(%[[VAL_6]], %[[VAL_0]]) <{num_splits = 2 : i32}> : (tensor, tensor<2x5x3xf32>) -> (tensor<1x5x3xf32>, tensor<1x5x3xf32>) # CHECK: %[[VAL_9:.*]] = "tfl.transpose"(%[[VAL_1]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32> -# CHECK: %[[VAL_10:.*]] = "tfl.fully_connected"(%[[VAL_7]]#0, %[[VAL_9]], %[[VAL_3]]) {fused_activation_function = "NONE", 
keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> -# CHECK: %[[VAL_11:.*]] = "tfl.fully_connected"(%[[VAL_7]]#1, %[[VAL_9]], %[[VAL_3]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> -# CHECK: %[[VAL_12:.*]] = "tfl.pack"(%[[VAL_10]], %[[VAL_11]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32> +# CHECK: %[[VAL_10:.*]] = "tfl.fully_connected"(%[[VAL_7]]#0, %[[VAL_9]], %[[VAL_3]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_11:.*]] = "tfl.fully_connected"(%[[VAL_7]]#1, %[[VAL_9]], %[[VAL_3]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_12:.*]] = "tfl.pack"(%[[VAL_10]], %[[VAL_11]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32> # CHECK: return %[[VAL_12]] : tensor<2x5x7xf32> # CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul_disabled.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul_disabled.pbtxt index c9287bc6184fb3..b75f3076054978 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul_disabled.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul_disabled.pbtxt @@ -78,6 +78,6 @@ versions { } # CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} { -# CHECK: %[[VAL_2:.*]] = "tfl.batch_matmul"(%[[VAL_0]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<2x5x3xf32>, tensor<3x7xf32>) -> tensor<2x5x7xf32> +# CHECK: %[[VAL_2:.*]] = "tfl.batch_matmul"(%[[VAL_0]], %[[VAL_1]]) <{adj_x = false, adj_y = false}> : (tensor<2x5x3xf32>, tensor<3x7xf32>) -> tensor<2x5x7xf32> # CHECK: return %[[VAL_2]] : tensor<2x5x7xf32> # CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index 2afbe2a0d2c766..5a5b9e32e8b9dd 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -31,6 +31,7 @@ filegroup( "//tensorflow/compiler/mlir/lite:flatbuffer_to_string", "//tensorflow/compiler/mlir/lite:flatbuffer_translate", "//tensorflow/compiler/mlir/lite:json_to_flatbuffer", + "//tensorflow/compiler/mlir/lite:tf_tfl_translate", "@llvm-project//llvm:FileCheck", ], ) @@ -55,8 +56,8 @@ tf_native_cc_binary( "importer_test_min_max.cc", ], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/lite:framework", - "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/basic_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/basic_lstm.mlir index af5f320a24f6e2..5f0845a391e783 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/basic_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/basic_lstm.mlir @@ -3,7 +3,7 @@ func.func @main(%arg0: tensor<1x384xf32>, %arg1: 
tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>) -> tensor<1x96xf32> { // CHECK-LABEL: @main -// CHECK: "tfl.basic_lstm"({{.*}}) {asymmetric_quantize_inputs = false, cell_clip = 1.000000e+00 : f32, fused_activation_function = "RELU", kernel_type = #tfl, proj_clip = 2.000000e+00 : f32} : (tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384xf32>, tensor<1x96xf32>) -> (tensor<1x96xf32>, tensor<1x96xf32>, tensor<1x480xf32>, tensor<1x384xf32>) +// CHECK: "tfl.basic_lstm"({{.*}}) <{cell_clip = 1.000000e+00 : f32, fused_activation_function = "RELU", kernel_type = #tfl, proj_clip = 2.000000e+00 : f32}> {asymmetric_quantize_inputs = false} : (tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384xf32>, tensor<1x96xf32>) -> (tensor<1x96xf32>, tensor<1x96xf32>, tensor<1x480xf32>, tensor<1x384xf32>) %0:4 = "tfl.basic_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", cell_clip = 1.0 : f32, proj_clip = 2.0 : f32} : (tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384xf32>, tensor<1x96xf32>) -> (tensor<1x96xf32>, tensor<1x96xf32>, tensor<1x480xf32>, tensor<1x384xf32>) func.return %0#0 : tensor<1x96xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/bucketize.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/bucketize.mlir index ba1bdcf9f28807..9b4463cda1a587 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/bucketize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/bucketize.mlir @@ -3,7 +3,7 @@ func.func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xi32> { // CHECK-LABEL: @main - // CHECK: "tfl.bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 1.000000e+01 : f32, 1.000000e+02 : f32]} : (tensor<3x2xf32>) -> tensor<3x2xi32> + // CHECK: "tfl.bucketize"(%arg0) <{boundaries = [0.000000e+00 : f32, 1.000000e+01 : f32, 1.000000e+02 : f32]}> : (tensor<3x2xf32>) -> tensor<3x2xi32> %0 = "tfl.bucketize"(%arg0) {boundaries = [0.0 : f32, 10.0 : f32, 100.0 : f32]} : (tensor<3x2xf32>) -> tensor<3x2xi32> func.return %0 : tensor<3x2xi32> } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/composite_op_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/composite_op_round_trip.mlir new file mode 100644 index 00000000000000..2084e8b7fe004d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/composite_op_round_trip.mlir @@ -0,0 +1,27 @@ +// RUN: tf_tfl_translate --enable-hlo-to-tf-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s --check-prefix=CHECK-ROUNDTRIP + + +module { + func.func public @main( %arg0: tensor) -> tensor { + %0 = func.call @test_add_roundtrip(%arg0) : (tensor) -> tensor + + return %0 : tensor + } + + + // CHECK-LABEL: func.func private @test_add_roundtrip + func.func private @test_add_roundtrip(%arg0: tensor) -> tensor { + // CHECK-ROUNDTRIP: %0 = stablehlo.composite "stablehlo.add_n" %arg0 {composite_attributes = {test_bool = false, test_int = 2 : i64, test_string = "test"}, decomposition = @add_n.impl} : (tensor) -> tensor + %0 = stablehlo.composite "stablehlo.add_n" %arg0 { composite_attributes = { test_int = 2 : i64, test_bool = 0 : i1, test_string = "test"}, decomposition = @add_n.impl } : (tensor) -> tensor + return %0 : tensor + } + func.func private @add_n.impl(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2> : tensor + %1 = stablehlo.add %arg0, %0 : tensor + return %1 : tensor + } + + + + +} \ No 
newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir index 6977b0e84c0d85..b5cc23bf58a739 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir @@ -102,42 +102,42 @@ func.func @int4() -> tensor<5xi4> { func.func @qi32_per_axis() -> tensor<3x3x!quant.uniform> { // CHECK-LABEL: @qi32_per_axis - // CHECK: {qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> + // CHECK: <{qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>}> : () -> tensor<3x3x!quant.uniform> %0 = "tfl.pseudo_qconst"() { qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> func.return %0 : tensor<3x3x!quant.uniform> } func.func @qi32_per_axis_zero() -> tensor<3x3x!quant.uniform> { // CHECK-LABEL: @qi32_per_axis_zero - // CHECK: {qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> + // CHECK: <{qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>}> : () -> tensor<3x3x!quant.uniform> %0 = "tfl.pseudo_qconst"() { qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> func.return %0 : tensor<3x3x!quant.uniform> } func.func @qu8() -> tensor<3x!quant.uniform:f32, 1.0>> { // CHECK-LABEL: @qu8 - // CHECK: {qtype = tensor<3x!quant.uniform:f32, 1.000000e+00>>, value = dense<1> : tensor<3xi8>} : () -> tensor<3x!quant.uniform:f32, 1.000000e+00>> + // CHECK: <{qtype = tensor<3x!quant.uniform:f32, 1.000000e+00>>, value = dense<1> : tensor<3xi8>}> : () -> tensor<3x!quant.uniform:f32, 1.000000e+00>> %0 = "tfl.pseudo_qconst"() { qtype = tensor<3x!quant.uniform:f32, 1.0>>, value = dense<1> : tensor<3xi8>} : () -> tensor<3x!quant.uniform:f32, 1.0>> func.return %0 : tensor<3x!quant.uniform:f32, 1.0>> } func.func @sparse_f32() -> tensor<3x2xf32> { // CHECK-LABEL: @sparse_f32 - // CHECK: {compressed_data = dense<[1.000000e+00, 2.000000e+00, 5.000000e-01, 2.500000e-01, -1.000000e+00, -2.000000e+00, -5.000000e-01, -2.500000e-01]> : tensor<8xf32>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0.000000e+00> : tensor<3x2xf32>} + // CHECK: <{compressed_data = dense<[1.000000e+00, 2.000000e+00, 5.000000e-01, 2.500000e-01, -1.000000e+00, -2.000000e+00, -5.000000e-01, -2.500000e-01]> : tensor<8xf32>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0.000000e+00> : tensor<3x2xf32>}> %0 = "tfl.pseudo_sparse_const"() {compressed_data = dense<[1.0, 2.0, 0.5, 0.25, -1.0, -2.0, -0.5, -0.25]> : tensor<8xf32>, s_param = #tfl.sparsity_parameter, #tfl.dimension_metadata, #tfl.dimension_metadata, #tfl.dimension_metadata>, value = dense<0.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32> func.return %0: tensor<3x2xf32> } func.func @sparse_f16() -> tensor<3x2xf16> { // CHECK-LABEL: @sparse_f16 - // CHECK: {compressed_data = dense<[1.000000e+00, 2.000000e+00, 5.000000e-01, 2.500000e-01, -1.000000e+00, -2.000000e+00, -5.000000e-01, -2.500000e-01]> : tensor<8xf16>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0.000000e+00> : tensor<3x2xf16>} + // CHECK: <{compressed_data = dense<[1.000000e+00, 2.000000e+00, 5.000000e-01, 2.500000e-01, -1.000000e+00, -2.000000e+00, -5.000000e-01, -2.500000e-01]> : tensor<8xf16>, s_param = #tfl.sparsity_parameter, , , >, value = 
dense<0.000000e+00> : tensor<3x2xf16>}> %0 = "tfl.pseudo_sparse_const"() {compressed_data = dense<[1.0, 2.0, 0.5, 0.25, -1.0, -2.0, -0.5, -0.25]> : tensor<8xf16>, s_param = #tfl.sparsity_parameter, #tfl.dimension_metadata, #tfl.dimension_metadata, #tfl.dimension_metadata>, value = dense<0.000000e+00> : tensor<3x2xf16>} : () -> tensor<3x2xf16> func.return %0: tensor<3x2xf16> } func.func @sparse_qu8() -> tensor<3x2x!quant.uniform:f32, 1.0>> { // CHECK-LABEL: @sparse_qu8 - // CHECK: {compressed_data = dense<[1, 2, 3, 4, -1, -2, -3, -4]> : tensor<8xi8>, qtype = tensor<3x2x!quant.uniform:f32, 1.000000e+00>>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0> : tensor<3x2xi8>} + // CHECK: <{compressed_data = dense<[1, 2, 3, 4, -1, -2, -3, -4]> : tensor<8xi8>, qtype = tensor<3x2x!quant.uniform:f32, 1.000000e+00>>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0> : tensor<3x2xi8>}> %0 = "tfl.pseudo_sparse_qconst"() {compressed_data = dense<[1, 2, 3, 4, -1, -2, -3, -4]> : tensor<8xi8>, qtype = tensor<3x2x!quant.uniform:f32, 1.0>>, s_param = #tfl.sparsity_parameter, #tfl.dimension_metadata, #tfl.dimension_metadata, #tfl.dimension_metadata>, value = dense<42> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform:f32, 1.0>> func.return %0: tensor<3x2x!quant.uniform:f32, 1.0>> } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants_offset.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants_offset.mlir index 1bbe28692c60eb..eeca24298432c2 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants_offset.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants_offset.mlir @@ -102,42 +102,42 @@ func.func @int4() -> tensor<5xi4> { func.func @qi32_per_axis() -> tensor<3x3x!quant.uniform> { // CHECK-LABEL: @qi32_per_axis - // CHECK: {qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> + // CHECK: <{qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>}> : () -> tensor<3x3x!quant.uniform> %0 = "tfl.pseudo_qconst"() { qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> func.return %0 : tensor<3x3x!quant.uniform> } func.func @qi32_per_axis_zero() -> tensor<3x3x!quant.uniform> { // CHECK-LABEL: @qi32_per_axis_zero - // CHECK: {qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> + // CHECK: <{qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>}> : () -> tensor<3x3x!quant.uniform> %0 = "tfl.pseudo_qconst"() { qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> func.return %0 : tensor<3x3x!quant.uniform> } func.func @qu8() -> tensor<3x!quant.uniform:f32, 1.0>> { // CHECK-LABEL: @qu8 - // CHECK: {qtype = tensor<3x!quant.uniform:f32, 1.000000e+00>>, value = dense<1> : tensor<3xi8>} : () -> tensor<3x!quant.uniform:f32, 1.000000e+00>> + // CHECK: <{qtype = tensor<3x!quant.uniform:f32, 1.000000e+00>>, value = dense<1> : tensor<3xi8>}> : () -> tensor<3x!quant.uniform:f32, 1.000000e+00>> %0 = "tfl.pseudo_qconst"() { qtype = tensor<3x!quant.uniform:f32, 1.0>>, value = dense<1> : tensor<3xi8>} : () -> tensor<3x!quant.uniform:f32, 1.0>> func.return %0 : tensor<3x!quant.uniform:f32, 1.0>> } func.func @sparse_f32() -> tensor<3x2xf32> { // CHECK-LABEL: @sparse_f32 - // CHECK: {compressed_data = dense<[1.000000e+00, 2.000000e+00, 5.000000e-01, 2.500000e-01, -1.000000e+00, -2.000000e+00, 
-5.000000e-01, -2.500000e-01]> : tensor<8xf32>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0.000000e+00> : tensor<3x2xf32>} + // CHECK: <{compressed_data = dense<[1.000000e+00, 2.000000e+00, 5.000000e-01, 2.500000e-01, -1.000000e+00, -2.000000e+00, -5.000000e-01, -2.500000e-01]> : tensor<8xf32>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0.000000e+00> : tensor<3x2xf32>}> %0 = "tfl.pseudo_sparse_const"() {compressed_data = dense<[1.0, 2.0, 0.5, 0.25, -1.0, -2.0, -0.5, -0.25]> : tensor<8xf32>, s_param = #tfl.sparsity_parameter, #tfl.dimension_metadata, #tfl.dimension_metadata, #tfl.dimension_metadata>, value = dense<0.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32> func.return %0: tensor<3x2xf32> } func.func @sparse_f16() -> tensor<3x2xf16> { // CHECK-LABEL: @sparse_f16 - // CHECK: {compressed_data = dense<[1.000000e+00, 2.000000e+00, 5.000000e-01, 2.500000e-01, -1.000000e+00, -2.000000e+00, -5.000000e-01, -2.500000e-01]> : tensor<8xf16>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0.000000e+00> : tensor<3x2xf16>} + // CHECK: <{compressed_data = dense<[1.000000e+00, 2.000000e+00, 5.000000e-01, 2.500000e-01, -1.000000e+00, -2.000000e+00, -5.000000e-01, -2.500000e-01]> : tensor<8xf16>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0.000000e+00> : tensor<3x2xf16>}> %0 = "tfl.pseudo_sparse_const"() {compressed_data = dense<[1.0, 2.0, 0.5, 0.25, -1.0, -2.0, -0.5, -0.25]> : tensor<8xf16>, s_param = #tfl.sparsity_parameter, #tfl.dimension_metadata, #tfl.dimension_metadata, #tfl.dimension_metadata>, value = dense<0.000000e+00> : tensor<3x2xf16>} : () -> tensor<3x2xf16> func.return %0: tensor<3x2xf16> } func.func @sparse_qu8() -> tensor<3x2x!quant.uniform:f32, 1.0>> { // CHECK-LABEL: @sparse_qu8 - // CHECK: {compressed_data = dense<[1, 2, 3, 4, -1, -2, -3, -4]> : tensor<8xi8>, qtype = tensor<3x2x!quant.uniform:f32, 1.000000e+00>>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0> : tensor<3x2xi8>} + // CHECK: <{compressed_data = dense<[1, 2, 3, 4, -1, -2, -3, -4]> : tensor<8xi8>, qtype = tensor<3x2x!quant.uniform:f32, 1.000000e+00>>, s_param = #tfl.sparsity_parameter, , , >, value = dense<0> : tensor<3x2xi8>}> %0 = "tfl.pseudo_sparse_qconst"() {compressed_data = dense<[1, 2, 3, 4, -1, -2, -3, -4]> : tensor<8xi8>, qtype = tensor<3x2x!quant.uniform:f32, 1.0>>, s_param = #tfl.sparsity_parameter, #tfl.dimension_metadata, #tfl.dimension_metadata, #tfl.dimension_metadata>, value = dense<42> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform:f32, 1.0>> func.return %0: tensor<3x2x!quant.uniform:f32, 1.0>> } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir index d798655017ee45..d1f7a4bb6423a2 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir @@ -5,4 +5,4 @@ func.func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, % func.return %0 : tensor<1x64x84x32xf32> } // CHECK-LABEL: main -// CHECK: "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = #tfl} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> +// CHECK: "tfl.custom"(%arg0, %arg1, %arg2) <{custom_code = "Convolution2DTransposeBias", custom_option = #tfl}> : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> diff --git 
a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op_offset.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op_offset.mlir index e1f3f1ed9e09f8..9f93a628b5d6a4 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op_offset.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op_offset.mlir @@ -5,4 +5,4 @@ func.func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, % func.return %0 : tensor<1x64x84x32xf32> } // CHECK-LABEL: main -// CHECK: "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = #tfl} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> +// CHECK: "tfl.custom"(%arg0, %arg1, %arg2) <{custom_code = "Convolution2DTransposeBias", custom_option = #tfl}> : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/external_constant.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/external_constant.mlir index 81259b9a2e288f..377a8d45bb991a 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/external_constant.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/external_constant.mlir @@ -8,7 +8,7 @@ func.func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { func.return %0 : tensor<40x40xf32> // CHECK-LABEL: func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> -// CHECK: %[[CONST:[0-9]+]] = "tfl.external_const"() {buffer_index = 3 : i32} : () -> tensor<40xf32> -// CHECK-NEXT: %[[FULL:[0-9]+]]:2 = "tfl.fully_connected"(%arg0, %arg1, %[[CONST]]) {asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: %[[CONST:[0-9]+]] = "tfl.external_const"() <{buffer_index = 3 : i32}> : () -> tensor<40xf32> +// CHECK-NEXT: %[[FULL:[0-9]+]]:2 = "tfl.fully_connected"(%arg0, %arg1, %[[CONST]]) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK-NEXT: return %[[FULL]]#0 } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json index c89592440d89ce..d00f1e6a58adcf 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json @@ -1,7 +1,7 @@ // RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s -// CHECK: %[[CST:.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %[[CST]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32> +// CHECK: %[[CST:.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %[[CST]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32> // CHECK: return %[[RES0]] : tensor<256x32x32x16xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc 
b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc index 8fc5a0cb051fdf..f4097c2e5e924b 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc @@ -24,8 +24,8 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/lite/model.h" -#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" using llvm::cl::opt; @@ -100,14 +100,14 @@ std::optional> InjectStatsToFullyConnected( // CHECK-LABEL: func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) // CHECK-SAME: -> tensor<40x40xf32> - // CHECK: %[[stat:.*]] = "quantfork.stats"(%arg0) {layerStats = dense< - // CHECK-SAME: [-1.000000e+00, 1.000000e+00]> : tensor<2xf32>} + // CHECK: %[[stat:.*]] = "quantfork.stats"(%arg0) <{layerStats = dense + // CHECK-SAME: <[-1.000000e+00, 1.000000e+00]> : tensor<2xf32>}> // CHECK-SAME: : (tensor<40x37xf32>) -> tensor<40x37xf32> - // CHECK-NEXT: %[[cst:.*]] = "tfl.pseudo_const"() {value = dense< - // CHECK-SAME: 1.000000e+00> : tensor<40xf32>} : () -> tensor<40xf32> + // CHECK-NEXT: %[[cst:.*]] = "tfl.pseudo_const"() <{value = dense< + // CHECK-SAME: 1.000000e+00> : tensor<40xf32>}> : () -> tensor<40xf32> // CHECK-NEXT: %[[fc:.*]]:2 = "tfl.fully_connected"(%[[stat]], %arg1, // CHECK-NEXT: %[[stat1:.*]] = "quantfork.stats"(%[[fc]]#0) - // CHECK-SAME: {axis = 1 : i64, + // CHECK-SAME: <{axis = 1 : i64, // CHECK-SAME: axisStats = dense<{{\[}}[-0.000000e+00, 0.000000e+00], // CHECK-SAME: [-1.000000e+00, 1.000000e+00], // CHECK-SAME: [-2.000000e+00, 2.000000e+00] diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc.mlir index 0f351b34f0fe7a..83e48bc1e102fa 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc.mlir @@ -8,7 +8,7 @@ func.func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { func.return %0 : tensor<40x40xf32> // CHECK-LABEL: func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> -// CHECK: %[[CONST:[0-9]+]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<40xf32>} -// CHECK-NEXT: %[[FULL:[0-9]+]]:2 = "tfl.fully_connected"(%arg0, %arg1, %[[CONST]]) {asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: %[[CONST:[0-9]+]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<40xf32>}> +// CHECK-NEXT: %[[FULL:[0-9]+]]:2 = "tfl.fully_connected"(%arg0, %arg1, %[[CONST]]) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK-NEXT: return %[[FULL]]#0 } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/legacy_reshape.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/legacy_reshape.json index d698473713a325..c0d80c5a95247c 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/legacy_reshape.json +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/legacy_reshape.json @@ -1,6 +1,6 @@ // RUN: json_to_flatbuffer %p/test_schema.fbs %s | 
flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s -// CHECK: %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<2> : tensor<2xi32>}> : () -> tensor<2xi32> // CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<1x4xf32>, tensor<2xi32>) -> tensor<2x2xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json index b6704814326ae2..9403af28efec36 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json @@ -1,8 +1,8 @@ // RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s // RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | flatbuffer_translate -mlir-to-tflite-flatbuffer - -o - | flatbuffer_to_string - | FileCheck --check-prefix=RoundTrip %s -// CHECK-DAG: %[[input_18:.*]] = "quantfork.stats"({{.*}}) {layerStats = dense<[-8.000000e-01, 1.600000e+00]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32> -// CHECK-DAG: %[[input_19:.*]] = "quantfork.stats"({{.*}}) {layerStats = dense<[-2.000000e+00, 4.000000e+00]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> +// CHECK-DAG: %[[input_18:.*]] = "quantfork.stats"({{.*}}) <{layerStats = dense<[-8.000000e-01, 1.600000e+00]> : tensor<2xf32>}> : (tensor<1x4xf32>) -> tensor<1x4xf32> +// CHECK-DAG: %[[input_19:.*]] = "quantfork.stats"({{.*}}) <{layerStats = dense<[-2.000000e+00, 4.000000e+00]> : tensor<2xf32>}> : (tensor<1x2xf32>) -> tensor<1x2xf32> // CHECK: "tfl.unidirectional_sequence_lstm"({{.*}}, %[[input_18]], %[[input_19]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) // CHECK-SAME: effective_hidden_scale_intermediate = tensor<*x!quant.calibrated>> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir index 33e5cca6e5de17..7b7b2d1273666b 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir @@ -8,10 +8,11 @@ func.func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x func.return %24 : tensor<1x4xf32> // CHECK-LABEL: main // separate lines since there is no region for this op. 
third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252 -// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ({ -// CHECK: }) {asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> +// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() <{value = dense<{{.*}}> : tensor<1x4xf32>}> {tfl.is_variable} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() <{value = dense<{{.*}}> : tensor<1x4xf32>}> {tfl.is_variable} : () -> tensor<1x4xf32> +// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) +// CHECK-SAME: <{asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32}> ({ +// CHECK: }) : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES2]] } @@ -23,16 +24,17 @@ func.func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, kernel_type = #tfl, proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> 
tensor<1x640x!quant.uniform> func.return %0 : tensor<1x640x!quant.uniform> // CHECK-LABEL: testFullyQuantizedLSTM -// CHECK: %[[CST:.*]] = "tfl.no_value"() {value} : () -> none +// CHECK: %[[CST:.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[CST]], %[[CST]], %[[CST]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) -// CHECK: }) {asymmetric_quantize_inputs = false, cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, kernel_type = #tfl, proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x640x!quant.uniform>, tensor<2048x640x!quant.uniform>, tensor<2048x640x!quant.uniform>, tensor<2048x640x!quant.uniform>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> +// CHECK-SAME: <{asymmetric_quantize_inputs = false, cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, kernel_type = #tfl, proj_clip = 0.00999999977 : f32}> ({ +// CHECK: }) : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x528x!quant.uniform>, tensor<2048x640x!quant.uniform>, tensor<2048x640x!quant.uniform>, tensor<2048x640x!quant.uniform>, tensor<2048x640x!quant.uniform>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> } // ----- // CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates func.func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, 
%arg23) {asymmetric_quantize_inputs = false, cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) <{asymmetric_quantize_inputs = false, cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -46,10 +48,11 @@ func.func @testLSTMAsymAttributeTrue(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {asymmetric_quantize_inputs = true, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %24 : tensor<1x4xf32> -// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK: %[[RES2:.*]] = 
"tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ({ -// CHECK: }) {asymmetric_quantize_inputs = true, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> +// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() <{value = dense<{{.*}}> : tensor<1x4xf32>}> {tfl.is_variable} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() <{value = dense<{{.*}}> : tensor<1x4xf32>}> {tfl.is_variable} : () -> tensor<1x4xf32> +// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) +// CHECK-SAME: <{asymmetric_quantize_inputs = true, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32}> ({ +// CHECK: }) : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES2]] } @@ -63,10 +66,11 @@ func.func @testLSTMAsymAttributeFalse(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4x %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %24 : tensor<1x4xf32> -// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ({ -// CHECK: }) {asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, 
tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> +// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() <{value = dense<{{.*}}> : tensor<1x4xf32>}> {tfl.is_variable} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() <{value = dense<{{.*}}> : tensor<1x4xf32>}> {tfl.is_variable} : () -> tensor<1x4xf32> +// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) +// CHECK-SAME: <{asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32}> ({ +// CHECK: }) : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES2]] } @@ -80,10 +84,11 @@ func.func @testLSTMAsymAttributeDefault(%arg0: tensor<1x4xf32>, %arg1: tensor<4x %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %24 : tensor<1x4xf32> -// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ({ -// CHECK: }) {asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> +// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() <{value = dense<{{.*}}> : tensor<1x4xf32>}> {tfl.is_variable} : () -> tensor<1x4xf32> +// 
CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() <{value = dense<{{.*}}> : tensor<1x4xf32>}> {tfl.is_variable} : () -> tensor<1x4xf32> +// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) +// CHECK-SAME: <{asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32}> ({ +// CHECK: }) : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES2]] } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/many_attribute_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/many_attribute_op.mlir index fb88f37615d731..faac23eedcff54 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/many_attribute_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/many_attribute_op.mlir @@ -3,7 +3,7 @@ // Confirm a wide array of attributes survives the round-trip func.func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> { ^bb0(%arg0: tensor<1x6x6x16xf32>): - // CHECK: "tfl.average_pool_2d"(%{{.*}}) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> + // CHECK: "tfl.average_pool_2d"(%{{.*}}) <{filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32}> : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> loc("avgpool") func.return %0 : tensor<1x1x1x16xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/math.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/math.mlir index 820bf2ca33decc..ac5c8a89725e29 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/math.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/math.mlir @@ -3,7 +3,7 @@ func.func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): - // CHECK: [[CONST:%.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> + // CHECK: [[CONST:%.*]] = "tfl.pseudo_const"() <{value = dense<1.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32> // CHECK-NEXT: [[SQDIFF:%.*]] = tfl.squared_difference %arg0, [[CONST]] : tensor<4xf32> // CHECK-NEXT: %{{.*}} = tfl.mul %arg0, [[SQDIFF]] {fused_activation_function = "NONE"} : tensor<4xf32> %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/matmul.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/matmul.mlir index c786f08ef30209..0b817991f659b7 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/matmul.mlir +++
b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/matmul.mlir @@ -5,7 +5,7 @@ func.func @main(%arg0: tensor<4x10x15xf32>, %arg1: tensor<4x15x17xf32>) -> tenso func.return %0: tensor<4x10x17xf32> // CHECK-LABEL: main -// CHECK: %[[RESULT0:.*]] = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<4x10x15xf32>, tensor<4x15x17xf32>) -> tensor<4x10x17xf32> +// CHECK: %[[RESULT0:.*]] = "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x10x15xf32>, tensor<4x15x17xf32>) -> tensor<4x10x17xf32> // CHECK: return %[[RESULT0]] } @@ -14,7 +14,7 @@ func.func @testMatmulAsymAttributeTrue(%arg0: tensor<4x10x15xf32>, %arg1: tensor %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = true} : (tensor<4x10x15xf32>, tensor<4x15x17xf32>) -> tensor<4x10x17xf32> func.return %0: tensor<4x10x17xf32> -// CHECK: %[[RESULT0:.*]] = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = true} : (tensor<4x10x15xf32>, tensor<4x15x17xf32>) -> tensor<4x10x17xf32> +// CHECK: %[[RESULT0:.*]] = "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = true}> : (tensor<4x10x15xf32>, tensor<4x15x17xf32>) -> tensor<4x10x17xf32> // CHECK: return %[[RESULT0]] } @@ -23,6 +23,6 @@ func.func @testMatmulAsymAttributeFalse(%arg0: tensor<4x10x15xf32>, %arg1: tenso %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<4x10x15xf32>, tensor<4x15x17xf32>) -> tensor<4x10x17xf32> func.return %0: tensor<4x10x17xf32> -// CHECK: %[[RESULT0:.*]] = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<4x10x15xf32>, tensor<4x15x17xf32>) -> tensor<4x10x17xf32> +// CHECK: %[[RESULT0:.*]] = "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x10x15xf32>, tensor<4x15x17xf32>) -> tensor<4x10x17xf32> // CHECK: return %[[RESULT0]] } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional.mlir index 2779c8ed684575..ada417fe9d4085 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional.mlir @@ -2,7 +2,7 @@ // Test to make sure optional parameters survive a roundtrip func.func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { -// CHECK: [[NONE:%.*]] = "tfl.no_value"() {value} : () -> none +// CHECK: [[NONE:%.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK: "tfl.fully_connected"(%arg0, %arg1, [[NONE]]) // CHECK-SAME: (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>, tensor<40x40xf32>) %cst = "tfl.no_value"() {value = unit} : () -> none diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional_input.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional_input.json index 3107e7ea2695c4..0d04ef052437e7 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional_input.json +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional_input.json @@ -2,8 +2,8 @@ // This test verifies that if the flatbuffer omits the last optional input `bias` of the tfl.conv_2d op, the flatbuffer_importer will automatically add a `none` value to tfl.conv_2d.
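For orientation, the hunks around here all follow the same pattern as this change: inherent op properties move from the `{...}` attribute dictionary into the `<{...}>` properties segment, while discardable attributes (for example `tfl.is_variable` in the lstm.mlir and variable.mlir hunks of this diff) stay in a trailing `{...}`. The following is a minimal, standalone sketch of the IR the importer produces for the omitted-`bias` case checked below; the shapes are copied from the CHECK lines of this test, and the function name @conv_no_bias is purely illustrative:

func.func @conv_no_bias(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x32x32x16xf32> {
  // The omitted optional `bias` operand is materialized by the importer as a
  // unit-valued tfl.no_value whose result type is `none`.
  %bias = "tfl.no_value"() <{value}> : () -> none
  // Inherent properties now print inside <{...}> rather than the old {...}.
  %0 = "tfl.conv_2d"(%arg0, %arg1, %bias) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>
  func.return %0 : tensor<256x32x32x16xf32>
}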
-// CHECK: %[[CST:.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %[[CST]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32> +// CHECK: %[[CST:.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %[[CST]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32> // CHECK: return %[[RES0]] : tensor<256x32x32x16xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quant_stats.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quant_stats.mlir index e0d581e4e799a9..8c7c1201841a70 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quant_stats.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quant_stats.mlir @@ -3,7 +3,7 @@ func.func @main(%arg0: tensor<1x512x672x8xf32>) -> tensor<1x512x672x8xf32> { // CHECK-LABEL: @main -// CHECK: %[[RES0:.*]] = "quantfork.stats"(%arg0) {layerStats = dense<[0.000000e+00, 2.550000e+02]> : tensor<2xf32>} : (tensor<1x512x672x8xf32>) -> tensor<1x512x672x8xf32> +// CHECK: %[[RES0:.*]] = "quantfork.stats"(%arg0) <{layerStats = dense<[0.000000e+00, 2.550000e+02]> : tensor<2xf32>}> : (tensor<1x512x672x8xf32>) -> tensor<1x512x672x8xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[0.000000e+00, 2.550000e+02]> : tensor<2xf32>} : (tensor<1x512x672x8xf32>) -> tensor<1x512x672x8xf32> func.return %0 : tensor<1x512x672x8xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quantization.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quantization.mlir index fdba4fd64b9697..6a6c3378155254 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quantization.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quantization.mlir @@ -2,10 +2,10 @@ // CHECK-LABEL: main func.func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { -// CHECK: %{{.*}} = "tfl.quantize"(%{{.*}}) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> +// CHECK: %{{.*}} = "tfl.quantize"(%{{.*}}) <{qtype = tensor<1x224x224x3x!quant.uniform>}> : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> // The float values here don't match exactly because double -> float -> double is lossy -// CHECK-NEXT: %{{.*}} = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678{{[0-9]*}}:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678{{[0-9]*}}:151>> -// CHECK-NEXT: %{{.*}} = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> +// CHECK-NEXT: %{{.*}} = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678{{[0-9]*}}:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}> : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678{{[0-9]*}}:151>> +// CHECK-NEXT: %{{.*}} = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>}> : () -> tensor<32x!quant.uniform> // CHECK: %{{.*}} = "tfl.dequantize"(%{{.*}}) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> %cst =
arith.constant dense<[1, 401408]> : tensor<2xi32> @@ -27,9 +27,9 @@ func.func @quantized_constant(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { %3 = "tfl.dequantize"(%2) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> func.return %3 : tensor<2x2xf32> -// CHECK-NEXT: %[[Q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK-NEXT: %[[CST:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x!quant.uniform>, value = dense<-76> : tensor<1x2xi8>} : () -> tensor<1x2x!quant.uniform> -// CHECK-NEXT: %[[CONCAT:.*]] = "tfl.concatenation"(%[[Q]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> +// CHECK-NEXT: %[[Q:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-NEXT: %[[CST:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x!quant.uniform>, value = dense<-76> : tensor<1x2xi8>}> : () -> tensor<1x2x!quant.uniform> +// CHECK-NEXT: %[[CONCAT:.*]] = "tfl.concatenation"(%[[Q]], %[[CST]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> // CHECK-NEXT: %[[DQ:.*]] = "tfl.dequantize"(%[[CONCAT]]) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> // CHECK-NEXT: return %[[DQ]] : tensor<2x2xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/reshape.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/reshape.mlir index 73344ddc4f535f..ae5f81c80b35f4 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/reshape.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/reshape.mlir @@ -2,7 +2,7 @@ // Confirm we can extract type info from reshape func.func @main() -> tensor<2x2xf32> { - // CHECK: %[[cst:.*]] = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[cst:.*]] = "tfl.pseudo_const"() <{value = dense<2> : tensor<2xi32>}> : () -> tensor<2xi32> // CHECK: %{{.*}} = "tfl.reshape"(%{{.*}}, %[[cst]]) : (tensor<4xf32>, tensor<2xi32>) -> tensor<2x2xf32> %cst = arith.constant dense<[2, 2]> : tensor<2xi32> %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir index 906e6efff29305..0d642ce2c1a6e8 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir @@ -10,9 +10,9 @@ func.func @main(tensor<3x2xi32>) -> tensor<3x2xi32> { // CHECK-SAME: tfl.description = "MLIR Converted." 
// CHECK-SAME: tfl.schema_version = 3 : i32 - // CHECK: %{{.*}} = "tfl.pseudo_const"() {value = dense<{{\[\[1, 2\], \[3, 4\], \[5, 6\]\]}}> : tensor<3x2xi32>} + // CHECK: %{{.*}} = "tfl.pseudo_const"() <{value = dense<{{\[\[1, 2\], \[3, 4\], \[5, 6\]\]}}> : tensor<3x2xi32>}> // CHECK-NEXT: [[SUB:%.*]] = tfl.sub %{{.*}}, %{{.*}} {fused_activation_function = "RELU6"} : tensor<3x2xi32> - // CHECK-NEXT: [[SCALAR:%.*]] = "tfl.pseudo_const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NEXT: [[SCALAR:%.*]] = "tfl.pseudo_const"() <{value = dense<10> : tensor}> : () -> tensor // CHECK-NEXT: [[ADD:%.*]] = tfl.add([[SCALAR]], [[SUB]]) {fused_activation_function = "NONE"} : (tensor, tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: return [[ADD]] : tensor<3x2xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/variable.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/variable.mlir index 0914fc37016771..40c8aa3ce64ddf 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/variable.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/variable.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: main func.func @main() -> tensor<3x2xi32> { - // CHECK: "tfl.pseudo_const"() {tfl.is_variable, value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + // CHECK: "tfl.pseudo_const"() <{value = dense<0> : tensor<3x2xi32>}> {tfl.is_variable} : () -> tensor<3x2xi32> %0 = "tfl.pseudo_const"() {value = dense<0> : tensor<3x2xi32>, tfl.is_variable} : () -> tensor<3x2xi32> loc("variable") func.return %0 : tensor<3x2xi32> } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir index 9b0e8c4863fc4b..51d685bd910561 100644 --- a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir +++ b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir @@ -1027,7 +1027,7 @@ func.func @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23810(%ar } // CHECK: func private @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf_type.string> {tf._user_specified_name = "input"}) -> (tensor, tensor) attributes {tf._implements = #tf_type.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf_type.shape<1>], tf.signature.is_stateful} { -// CHECK: %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = #tfl} : (tensor<1x!tf_type.string>) -> (tensor, tensor) +// CHECK: %0:2 = "tfl.custom"(%arg0) <{custom_code = "tftext:WhitespaceTokenizer", custom_option = #tfl}> : (tensor<1x!tf_type.string>) -> (tensor, tensor) // CHECK: return %0#0, %0#1 : tensor, tensor func.func private @whitespace_tokenizer_rank2(%arg0: tensor {tf._user_specified_name = "input"}) -> (tensor, tensor, tensor) attributes {tf._input_shapes = [#tf_type.shape], tf._implements = #tf_type.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} { @@ -2161,7 +2161,7 @@ func.func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather // CHECK: func private @whitespace_tokenizer_rank2(%arg0: tensor {tf._user_specified_name = "input"}) -> (tensor, tensor, tensor) attributes {tf._implements = #tf_type.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf_type.shape], tf.signature.is_stateful} { -// CHECK: %0:3 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = #tfl} : (tensor) -> (tensor, tensor, tensor) +// CHECK: %0:3 = "tfl.custom"(%arg0) <{custom_code = "tftext:WhitespaceTokenizer", custom_option = #tfl}> : (tensor) -> (tensor, 
tensor, tensor) // CHECK: return %0#0, %0#1, %0#2 : tensor, tensor, tensor func.func private @whitespace_tokenizer_rank0(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {tf._input_shapes = [#tf_type.shape<>], tf._implements = #tf_type.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} { @@ -3191,7 +3191,7 @@ func.func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertG } // CHECK: func private @whitespace_tokenizer_rank0(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {tf._implements = #tf_type.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} { -// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = #tfl} : (tensor) -> tensor +// CHECK: %0 = "tfl.custom"(%arg0) <{custom_code = "tftext:WhitespaceTokenizer", custom_option = #tfl}> : (tensor) -> tensor // CHECK: return %0 : tensor func.func @ngrams(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {tf._input_shapes = [#tf_type.shape], tf._implements = #tf_type.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = " ", width = 2 : i64}>} { @@ -3209,7 +3209,7 @@ func.func @ngrams(%arg0: tensor {tf._user_specified_name = "i } // CHECK: func @ngrams(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {tf._implements = #tf_type.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = " ", width = 2 : i64}>, tf._input_shapes = [#tf_type.shape]} { -// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:Ngrams", custom_option = #tfl} : (tensor) -> tensor +// CHECK: %0 = "tfl.custom"(%arg0) <{custom_code = "tftext:Ngrams", custom_option = #tfl}> : (tensor) -> tensor // CHECK: return %0 : tensor // CHECK: } @@ -3434,7 +3434,7 @@ func.func private @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_Asser func.return %5 : tensor } // CHECK: func private @ngrams_ragged_rank_2(%arg0: tensor {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor {tf._user_specified_name = "args_1"}) -> (tensor, tensor<3xi64>, tensor) attributes {tf._implements = #tf_type.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf_type.shape, #tf_type.shape<3>, #tf_type.shape], tf.signature.is_stateful} { -// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = #tfl} : (tensor, tensor<3xi64>, tensor) -> (tensor, tensor<3xi64>, tensor) +// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) <{custom_code = "tftext:Ngrams", custom_option = #tfl}> : (tensor, tensor<3xi64>, tensor) -> (tensor, tensor<3xi64>, tensor) // CHECK: return %0#0, %0#1, %0#2 : tensor, tensor<3xi64>, tensor @@ -3449,5 +3449,5 @@ func.func private @sgnn_projection(%arg0: tensor {tf._user_sp // CHECK: func private @sgnn_projection(%arg0: tensor {tf._user_specified_name = "values"}, %arg1: tensor {tf._user_specified_name = "row_splits"}) -> tensor attributes {tf._implements = #tf_type.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf_type.shape, #tf_type.shape], tf.signature.is_stateful} { -// CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = 
"tftext:custom:SgnnProjection", custom_option = #tfl} : (tensor, tensor) -> tensor +// CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "tftext:custom:SgnnProjection", custom_option = #tfl}> : (tensor, tensor) -> tensor // CHECK: return %0 : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/insert_call_once_op.mlir b/tensorflow/compiler/mlir/lite/tests/insert_call_once_op.mlir index db3d75e93a805a..44ce1b5c96e90d 100644 --- a/tensorflow/compiler/mlir/lite/tests/insert_call_once_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/insert_call_once_op.mlir @@ -22,7 +22,7 @@ module attributes {tf_saved_model.semantics} { %1 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor, tensor, tensor) -> tensor<*x!tf_type.string> func.return %1 : tensor<*x!tf_type.string> // CHECK-LABEL: @serving_default - // CHECK: "tfl.call_once"() {session_init_function = "init_all_tables"} : () -> () + // CHECK: "tfl.call_once"() <{session_init_function = "init_all_tables"}> : () -> () } } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir index 74b10665528406..2cf209c8e0444f 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: listReserveScalarShapeI32 func.func @listReserveScalarShapeI32(%arg0: tensor, %arg1: tensor) -> tensor>> { %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor, tensor) -> tensor>> - // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListReserve", custom_option = #tfl} : (tensor, tensor) -> tensor>> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "TensorListReserve", custom_option = #tfl}> : (tensor, tensor) -> tensor>> func.return %0 : tensor>> } @@ -13,7 +13,7 @@ func.func @listReserveScalarShapeI32(%arg0: tensor, %arg1: tensor) -> // CHECK-LABEL: listReserve1DShapeI32 func.func @listReserve1DShapeI32(%arg0: tensor<2xi32>, %arg1: tensor) -> tensor>> { %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<2xi32>, tensor) -> tensor>> - // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListReserve", custom_option = #tfl} : (tensor<2xi32>, tensor) -> tensor>> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "TensorListReserve", custom_option = #tfl}> : (tensor<2xi32>, tensor) -> tensor>> func.return %0 : tensor>> } @@ -22,7 +22,7 @@ func.func @listReserve1DShapeI32(%arg0: tensor<2xi32>, %arg1: tensor) -> te // CHECK-LABEL: listReserveScalarShapeFloat func.func @listReserveScalarShapeFloat(%arg0: tensor, %arg1: tensor) -> tensor>> { %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor, tensor) -> tensor>> - // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListReserve", custom_option = #tfl} : (tensor, tensor) -> tensor>> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "TensorListReserve", custom_option = #tfl}> : (tensor, tensor) -> tensor>> func.return %0 : tensor>> } @@ -31,7 +31,7 @@ func.func @listReserveScalarShapeFloat(%arg0: tensor, %arg1: tensor) - // CHECK-LABEL: listReserveScalarShapeLong func.func @listReserveScalarShapeLong(%arg0: tensor, %arg1: tensor) -> tensor>> { %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor, tensor) -> tensor>> - // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListReserve", custom_option = #tfl} : (tensor, tensor) -> tensor>> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "TensorListReserve", custom_option = #tfl}> : (tensor, 
tensor) -> tensor>> func.return %0 : tensor>> } @@ -40,7 +40,7 @@ func.func @listReserveScalarShapeLong(%arg0: tensor, %arg1: tensor) -> // CHECK-LABEL: listReserveScalarShapeBool func.func @listReserveScalarShapeBool(%arg0: tensor, %arg1: tensor) -> tensor>> { %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor, tensor) -> tensor>> - // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListReserve", custom_option = #tfl} : (tensor, tensor) -> tensor>> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "TensorListReserve", custom_option = #tfl}> : (tensor, tensor) -> tensor>> func.return %0 : tensor>> } @@ -49,7 +49,7 @@ func.func @listReserveScalarShapeBool(%arg0: tensor, %arg1: tensor) -> // CHECK-LABEL: listStack func.func @listStack(%arg0: tensor>>, %arg1: tensor) -> tensor<*xi32> { %0 = "tf.TensorListStack"(%arg0, %arg1) : (tensor>>, tensor) -> tensor<*xi32> - // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListStack", custom_option = #tfl} : (tensor>>, tensor) -> tensor<*xi32> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "TensorListStack", custom_option = #tfl}> : (tensor>>, tensor) -> tensor<*xi32> func.return %0 : tensor<*xi32> } @@ -58,7 +58,7 @@ func.func @listStack(%arg0: tensor>>, %arg1: tens // CHECK-LABEL: listSetItem func.func @listSetItem(%arg0: tensor>>, %arg1: tensor, %arg2: tensor<*xi32>) -> tensor>> { %0 = "tf.TensorListSetItem"(%arg0, %arg1, %arg2) : (tensor>>, tensor, tensor<*xi32>) -> tensor>> - // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "TensorListSetItem", custom_option = #tfl} : (tensor>>, tensor, tensor<*xi32>) -> tensor>> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2) <{custom_code = "TensorListSetItem", custom_option = #tfl}> : (tensor>>, tensor, tensor<*xi32>) -> tensor>> func.return %0 : tensor>> } @@ -67,7 +67,7 @@ func.func @listSetItem(%arg0: tensor>>, %arg1: te // CHECK-LABEL: listGetItem func.func @listGetItem(%arg0: tensor>>, %arg1: tensor, %arg2: tensor<2xi32>) -> tensor<2xi32> { %0 = "tf.TensorListGetItem"(%arg0, %arg1, %arg2) : (tensor>>, tensor, tensor<2xi32>) -> tensor<2xi32> - // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "TensorListGetItem", custom_option = #tfl} : (tensor>>, tensor, tensor<2xi32>) -> tensor<2xi32> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2) <{custom_code = "TensorListGetItem", custom_option = #tfl}> : (tensor>>, tensor, tensor<2xi32>) -> tensor<2xi32> func.return %0 : tensor<2xi32> } @@ -77,7 +77,7 @@ func.func @listGetItem(%arg0: tensor>>, %arg1: te func.func @listFromTensor(%tensor: tensor<3xi32>, %shape : tensor) -> tensor>> { %0 = "tf.TensorListFromTensor"(%tensor, %shape) : (tensor<3xi32>, tensor) -> tensor>> func.return %0 : tensor>> - // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListFromTensor", custom_option = #tfl} : (tensor<3xi32>, tensor) -> tensor>> + // CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "TensorListFromTensor", custom_option = #tfl}> : (tensor<3xi32>, tensor) -> tensor>> } // ----- @@ -95,7 +95,7 @@ func.func @typeNotSupportedNotLegalized(%arg0: tensor>>) -> tensor { %0 = "tf.TensorListLength"(%arg0) : (tensor>>) -> tensor - // CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "TensorListLength", custom_option = #tfl} : (tensor>>) -> tensor + // CHECK: %0 = "tfl.custom"(%arg0) <{custom_code = "TensorListLength", custom_option = #tfl}> : (tensor>>) -> tensor func.return %0 : tensor } @@ -105,7 +105,7 @@ func.func @listLength(%arg0: tensor>>) -> tensor< func.func 
@listEmptyToListReserve(%arg0: tensor, %arg1: tensor) -> tensor>> {
%0 = "tf.EmptyTensorList"(%arg0, %arg1) : (tensor, tensor) -> tensor>>
// CHECK: %cst = arith.constant dense<0> : tensor
- // CHECK: %0 = "tfl.custom"(%arg0, %cst) {custom_code = "TensorListReserve", custom_option = #tfl} : (tensor, tensor) -> tensor>>
+ // CHECK: %0 = "tfl.custom"(%arg0, %cst) <{custom_code = "TensorListReserve", custom_option = #tfl}> : (tensor, tensor) -> tensor>>
func.return %0 : tensor>>
}
@@ -114,7 +114,7 @@ func.func @listEmptyToListReserve(%arg0: tensor, %arg1: tensor) -> t
// CHECK-LABEL: listElementShape
func.func @listElementShape(%arg0: tensor>>) -> tensor<*xi32> {
%0 = "tf.TensorListElementShape"(%arg0) : (tensor>>) -> tensor<*xi32>
- // CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "TensorListElementShape", custom_option = #tfl} : (tensor>>) -> tensor<*xi32>
+ // CHECK: %0 = "tfl.custom"(%arg0) <{custom_code = "TensorListElementShape", custom_option = #tfl}> : (tensor>>) -> tensor<*xi32>
func.return %0 : tensor<*xi32>
}
@@ -123,7 +123,7 @@ func.func @listElementShape(%arg0: tensor>>) -> t
// CHECK-LABEL: listPopBack
func.func @listPopBack(%arg0: tensor>>, %arg1: tensor<1xi32>) -> (tensor>>, tensor<2xi32>) {
%0, %1 = "tf.TensorListPopBack"(%arg0, %arg1) : (tensor>>, tensor<1xi32>) -> (tensor>>, tensor<2xi32>)
- // CHECK: %0:2 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListPopBack", custom_option = #tfl} : (tensor>>, tensor<1xi32>) -> (tensor>>, tensor<2xi32>)
+ // CHECK: %0:2 = "tfl.custom"(%arg0, %arg1) <{custom_code = "TensorListPopBack", custom_option = #tfl}> : (tensor>>, tensor<1xi32>) -> (tensor>>, tensor<2xi32>)
func.return %0, %1 : tensor>>, tensor<2xi32>
}
@@ -132,7 +132,7 @@ func.func @listPopBack(%arg0: tensor>>, %arg1: te
// CHECK-LABEL: listPushBack
func.func @listPushBack(%arg0: tensor>>, %arg1: tensor<16x1xf32>) -> tensor>> {
%0 = "tf.TensorListPushBack"(%arg0, %arg1) : (tensor>>, tensor<16x1xf32>) -> tensor>>
- // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListPushBack", custom_option = #tfl} : (tensor>>, tensor<16x1xf32>) -> tensor>>
+ // CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "TensorListPushBack", custom_option = #tfl}> : (tensor>>, tensor<16x1xf32>) -> tensor>>
func.return %0: tensor>>
}
@@ -141,7 +141,7 @@ func.func @listPushBack(%arg0: tensor>>, %arg1:
// CHECK-LABEL: variantAddN
func.func @variantAddN(%arg0: tensor>>, %arg1: tensor>>) -> tensor>> {
%1 = "tf.AddN"(%arg0, %arg1) : (tensor>>, tensor>>) -> tensor>>
- // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "VariantAddN", custom_option = #tfl} : (tensor>>, tensor>>) -> tensor>>
+ // CHECK: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "VariantAddN", custom_option = #tfl}> : (tensor>>, tensor>>) -> tensor>>
func.return %1 : tensor>>
}
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-hashtables.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-hashtables.mlir
index 2e90bdff67ad6e..70d9701ac63b12 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-hashtables.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-hashtables.mlir
@@ -7,7 +7,7 @@ func.func @hashtable_string_to_int64(%arg0: tensor) -> tensor<*xi64> {
%1 = "tf.LookupTableFindV2"(%0, %cst, %arg0) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi64>
// CHECK-LABEL: hashtable_string_to_int64
// CHECK: [[CST:%.*]] = arith.constant dense<"f"> : tensor
- // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() {key_dtype = !tf_type.string, table_id = 1530976467 : i32, value_dtype = i64} : () -> tensor<1x!tf_type.resource>
+ // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() <{key_dtype = !tf_type.string, table_id = 1530976467 : i32, value_dtype = i64}> : () -> tensor<1x!tf_type.resource>
// CHECK-NEXT: [[FIND:%.*]] = "tfl.hashtable_find"([[HASH_TABLE]], [[CST]], %arg0) : (tensor<1x!tf_type.resource>, tensor, tensor) -> tensor<*xi64>
// CHECK-NEXT: return [[FIND]] : tensor<*xi64>
func.return %1 : tensor<*xi64>
@@ -22,7 +22,7 @@ func.func @hashtable_int64_to_string(%arg0: tensor) -> tensor<*x!tf_type.st
%1 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor, tensor, tensor) -> tensor<*x!tf_type.string>
// CHECK-LABEL: hashtable_int64_to_string
// CHECK: [[CST:%.*]] = arith.constant dense<"f"> : tensor
- // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() {key_dtype = i64, table_id = 1530976467 : i32, value_dtype = !tf_type.string} : () -> tensor<1x!tf_type.resource>
+ // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() <{key_dtype = i64, table_id = 1530976467 : i32, value_dtype = !tf_type.string}> : () -> tensor<1x!tf_type.resource>
// CHECK-NEXT: [[FIND:%.*]] = "tfl.hashtable_find"([[HASH_TABLE]], %arg0, [[CST]]) : (tensor<1x!tf_type.resource>, tensor, tensor) -> tensor<*x!tf_type.string>
// CHECK-NEXT: return [[FIND]] : tensor<*x!tf_type.string>
func.return %1 : tensor<*x!tf_type.string>
@@ -52,7 +52,7 @@ func.func @hashtable_import(%arg0: tensor<5x!tf_type.string>) {
// CHECK-LABEL: hashtable_import
// CHECK: [[CST:%.*]] = arith.constant dense<["emerson", "lake", "palmer"]> : tensor<3x!tf_type.string>
// CHECK-NEXT: [[CST_0:%.*]] = arith.constant dense<[0, 1, 2]> : tensor<3xi64>
- // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() {key_dtype = !tf_type.string, table_id = -1323619995 : i32, value_dtype = i64} : () -> tensor<1x!tf_type.resource>
+ // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() <{key_dtype = !tf_type.string, table_id = -1323619995 : i32, value_dtype = i64}> : () -> tensor<1x!tf_type.resource>
// CHECK-NEXT: "tfl.hashtable_import"([[HASH_TABLE]], [[CST]], [[CST_0]]) : (tensor<1x!tf_type.resource>, tensor<3x!tf_type.string>, tensor<3xi64>) -> ()
}
@@ -63,7 +63,7 @@ func.func @hashtable_size(%arg0: tensor<5x!tf_type.string>) -> tensor {
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_1dd4fef4-646d-491f-a3a8-bf5334f45813", use_node_name_sharing = false, value_dtype = i64} : () -> tensor
%1 = "tf.LookupTableSizeV2"(%0) {device = ""} : (tensor) -> tensor
// CHECK-LABEL: hashtable_size
- // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() {key_dtype = !tf_type.string, table_id = -1323619995 : i32, value_dtype = i64} : () -> tensor<1x!tf_type.resource>
+ // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() <{key_dtype = !tf_type.string, table_id = -1323619995 : i32, value_dtype = i64}> : () -> tensor<1x!tf_type.resource>
// CHECK-NEXT: [[SIZE:%.*]] = "tfl.hashtable_size"([[HASH_TABLE]]) : (tensor<1x!tf_type.resource>) -> tensor
// CHECK-NEXT: return [[SIZE]] : tensor
func.return %1 : tensor
@@ -83,7 +83,7 @@ func.func @hashtable_import_then_find(%arg0: tensor<5x!tf_type.string>) -> tenso
// CHECK: [[CST:%.*]] = arith.constant dense<["emerson", "lake", "palmer"]> : tensor<3x!tf_type.string>
// CHECK-NEXT: [[CST_0:%.*]] = arith.constant dense<-1> : tensor
// CHECK-NEXT: [[CST_1:%.*]] = arith.constant dense<[0, 1, 2]> : tensor<3xi64>
- // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() {key_dtype = !tf_type.string, table_id = -1323619995 : i32, value_dtype = i64} : () -> tensor<1x!tf_type.resource>
+ // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() <{key_dtype = !tf_type.string, table_id = -1323619995 : i32, value_dtype = i64}> : () -> tensor<1x!tf_type.resource>
// CHECK-NEXT: "tfl.hashtable_import"([[HASH_TABLE]], [[CST]], [[CST_1]]) : (tensor<1x!tf_type.resource>, tensor<3x!tf_type.string>, tensor<3xi64>) -> ()
// CHECK-NEXT: [[FIND:%.*]] = "tfl.hashtable_find"([[HASH_TABLE]], %arg0, [[CST_0]]) : (tensor<1x!tf_type.resource>, tensor<5x!tf_type.string>, tensor) -> tensor<*xi64>
// CHECK-NEXT: return [[FIND]] : tensor<*xi64>
@@ -102,7 +102,7 @@ func.func @hashtable_import_then_size(%arg0: tensor<5x!tf_type.string>) -> tenso
// CHECK-LABEL: hashtable_import_then_size
// CHECK: [[CST:%.*]] = arith.constant dense<["emerson", "lake", "palmer"]> : tensor<3x!tf_type.string>
// CHECK-NEXT: [[CST_0:%.*]] = arith.constant dense<[0, 1, 2]> : tensor<3xi64>
- // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() {key_dtype = !tf_type.string, table_id = -1323619995 : i32, value_dtype = i64} : () -> tensor<1x!tf_type.resource>
+ // CHECK-NEXT: [[HASH_TABLE:%.*]] = "tfl.hashtable"() <{key_dtype = !tf_type.string, table_id = -1323619995 : i32, value_dtype = i64}> : () -> tensor<1x!tf_type.resource>
// CHECK-NEXT: "tfl.hashtable_import"([[HASH_TABLE]], [[CST]], [[CST_0]]) : (tensor<1x!tf_type.resource>, tensor<3x!tf_type.string>, tensor<3xi64>) -> ()
// CHECK-NEXT: [[SIZE:%.*]] = "tfl.hashtable_size"([[HASH_TABLE]]) : (tensor<1x!tf_type.resource>) -> tensor
// CHECK-NEXT: return [[SIZE]] : tensor
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir
index 46ff509b7cc46e..2c17e734c58dad 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir
@@ -6,6 +6,6 @@ func.func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> ten
// CHECK-LABEL: broadcast_to_bf16
// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xbf16>
-// CHECK: [[MUL:%.*]] = tfl.mul(%arg0, [[CST]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
+// CHECK: [[MUL:%.*]] = tfl.mul(%arg0, [[CST]]) <{fused_activation_function = "NONE"}> : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
// CHECK: return [[MUL]] : tensor<3x3xbf16>
}
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-variables.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-variables.mlir
index 54a10bf0dad5bc..36b26c78b258f8 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-variables.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-variables.mlir
@@ -15,8 +15,8 @@ module attributes {tf_saved_model.semantics} {
func.return %2, %3 : tensor<1x10xf32>, tensor<1x10xi64>
}
- // CHECK: %[[RESOURCE:.*]] = "tfl.var_handle"() {container = "c", shared_name = "a"} : () -> tensor>>
- // CHECK: %[[RESOURCE_1:.*]] = "tfl.var_handle"() {container = "c", shared_name = "b"} : () -> tensor>>
+ // CHECK: %[[RESOURCE:.*]] = "tfl.var_handle"() <{container = "c", shared_name = "a"}> : () -> tensor>>
+ // CHECK: %[[RESOURCE_1:.*]] = "tfl.var_handle"() <{container = "c", shared_name = "b"}> : () -> tensor>>
// CHECK: %[[VAR_VAL:.*]] = "tfl.read_variable"(%[[RESOURCE]]) : (tensor>>) -> tensor<1x10xf32>
// CHECK: %[[ADD:.*]] = tfl.add %[[VAR_VAL]], %arg0 {fused_activation_function = "NONE"} : tensor<1x10xf32>
// CHECK: "tfl.assign_variable"(%[[RESOURCE]], %[[ADD]]) : (tensor>>, tensor<1x10xf32>) -> ()
CHECK: "tfl.assign_variable"(%[[RESOURCE]], %[[ADD]]) : (tensor>>, tensor<1x10xf32>) -> () @@ -41,7 +41,7 @@ module attributes {tf_saved_model.semantics} { "tf.AssignVariableOp"(%handle_0, %cst_1) : (tensor>>, tensor<1x10xf32>) -> () func.return // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<1x10xf32> - // CHECK: %[[RESOURCE:.*]] = "tfl.var_handle"() {container = "c", shared_name = "a"} : () -> tensor>> + // CHECK: %[[RESOURCE:.*]] = "tfl.var_handle"() <{container = "c", shared_name = "a"}> : () -> tensor>> // CHECK: "tfl.assign_variable"(%[[RESOURCE]], %[[CST]]) : (tensor>>, tensor<1x10xf32>) -> () } @@ -57,7 +57,7 @@ module attributes {tf_saved_model.semantics} { "tf.AssignVariableOp"(%handle_0, %1) : (tensor>>, tensor<1x10xf32>) -> () %2 = "tf.ReadVariableOp"(%handle_0) {device = ""} : (tensor>>) -> tensor<1x10xf32> func.return %2 : tensor<1x10xf32> - // CHECK: %[[RESOURCE:.*]] = "tfl.var_handle"() {container = "c", shared_name = "a"} : () -> tensor>> + // CHECK: %[[RESOURCE:.*]] = "tfl.var_handle"() <{container = "c", shared_name = "a"}> : () -> tensor>> // CHECK: %[[VAR_VAL:.*]] = "tfl.read_variable"(%[[RESOURCE]]) : (tensor>>) -> tensor<1x10xf32> // CHECK: %[[ADD:.*]] = tfl.add %[[VAR_VAL]], %arg0 {fused_activation_function = "NONE"} : tensor<1x10xf32> // CHECK: "tfl.assign_variable"(%[[RESOURCE]], %[[ADD]]) : (tensor>>, tensor<1x10xf32>) -> () diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir index ab9b39bd94cb97..0dfd9e1c1a78bf 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir @@ -58,7 +58,7 @@ func.func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %ar // CANON-SAME: (tensor, tensor<256x256xf32>, tensor) // CANON: [[VAL_1:%.*]] = arith.constant dense<1.000000e+00> : tensor<256x256xf32> // CANON: [[VAL_2:%.*]] = arith.constant dense<0> : tensor -// CANON: [[VAL_6:%.*]]:3 = "tfl.while"([[VAL_2]], [[VAL_2]], [[VAL_0]]) ({ +// CANON: [[VAL_6:%.*]]:3 = "tfl.while"([[VAL_2]], [[VAL_2]], [[VAL_0]]) <{is_stateless = true}> ({ // CANON: ^bb0([[VAL_7:%.*]]: tensor<*xi32>, [[VAL_8:%.*]]: tensor<*xi32>, [[VAL_9:%.*]]: tensor<*xf32>): // CANON: [[VAL_3:%.*]] = arith.constant dense<10> : tensor // CANON: [[VAL_10:%.*]] = "tf.Less"([[VAL_8]], [[VAL_3]]) @@ -71,6 +71,6 @@ func.func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %ar // CANON: [[VAL_15:%.*]] = "tf.AddV2"([[VAL_13]], [[VAL_5]]) // CANON: [[VAL_16:%.*]] = "tf.AddV2"([[VAL_11]], [[VAL_4]]) // CANON: "tfl.yield"([[VAL_16]], [[VAL_14]], [[VAL_15]]) : (tensor<*xi32>, tensor<*xi32>, tensor<*xf32>) -> () -// CANON: }) {is_stateless = true} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) +// CANON: }) : (tensor, tensor, tensor) -> (tensor, tensor, tensor) // CANON: return [[VAL_17:%.*]]#1, [[VAL_1]], [[VAL_17]]#2 : tensor, tensor<256x256xf32>, tensor // CANON: } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index a0b9f90a879507..2e178f754dbbc4 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -30,7 +30,7 @@ func.func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> { func.return %2: tensor<1xf32> // CHECK-LABEL: LeakyRelu -// CHECK: "tfl.leaky_relu"(%arg0) {alpha = 1.000000e-01 : f32} : (tensor<1xf32>) -> tensor<1xf32> +// CHECK: "tfl.leaky_relu"(%arg0) 
<{alpha = 1.000000e-01 : f32}> : (tensor<1xf32>) -> tensor<1xf32> } func.func @biasAdd(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tensor<1x10x10x32xf32> { @@ -38,7 +38,7 @@ func.func @biasAdd(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tens func.return %0 : tensor<1x10x10x32xf32> // CHECK-LABEL: biasAdd -// CHECK: tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> +// CHECK: tfl.add(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> } func.func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x10x10x32xi32> { @@ -57,8 +57,8 @@ func.func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor) %4 = "tf.some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32 func.return %4 : i32 // CHECK-LABEL: squeezeAndReshape -// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32> -// CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor) -> tensor<*xf32> +// CHECK: "tfl.squeeze"(%arg0) <{squeeze_dims = [0]}> : (tensor<1x1x10xf32>) -> tensor<1x10xf32> +// CHECK: %1 = "tfl.squeeze"(%arg1) <{squeeze_dims = []}> : (tensor) -> tensor<*xf32> // CHECK: %cst = arith.constant dense<[2, 5]> : tensor<2xi32> // CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32> // CHECK: %3 = "tf.some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32 @@ -118,7 +118,7 @@ func.func @avgPool2D(%arg0: tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> { func.return %6 : tensor<1x1x1x16xf32> // CHECK-LABEL: func @avgPool2D -// CHECK: "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> +// CHECK: "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32}> : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> // CHECK: %1 = "tf.AvgPool"(%arg0) // CHECK: %2 = "tf.AvgPool"(%arg0) } @@ -138,7 +138,7 @@ func.func @avgPool2DChannelFirst(%arg0: tensor<1x16x6x6xf32>) -> tensor<1x16x1x1 // CHECK-LABEL: func @avgPool2DChannelFirst // CHECK: %cst = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> // CHECK: %0 = "tfl.transpose"(%arg0, %cst) : (tensor<1x16x6x6xf32>, tensor<4xi32>) -> tensor<1x6x6x16xf32> -// CHECK: %1 = "tfl.average_pool_2d"(%0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> +// CHECK: %1 = "tfl.average_pool_2d"(%0) <{filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32}> : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> // CHECK: %cst_0 = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> // CHECK: %2 = "tfl.transpose"(%1, %cst_0) : (tensor<1x1x1x16xf32>, tensor<4xi32>) -> tensor<1x16x1x1xf32> // CHECK: %3 = "tf.AvgPool"(%arg0) @@ -150,7 +150,7 @@ func.func @softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { func.return %0 : tensor<8x16xf32> // CHECK-LABEL: softmax -// CHECK: "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x16xf32>) -> tensor<8x16xf32> +// CHECK: "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : f32}> : 
(tensor<8x16xf32>) -> tensor<8x16xf32> } func.func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { @@ -160,7 +160,7 @@ func.func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { // CHECK-LABEL: softplus // CHECK: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor -// CHECK: %[[add:.*]] = tfl.add(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> +// CHECK: %[[add:.*]] = tfl.add(%[[exp]], %[[cst]]) <{fused_activation_function = "NONE"}> : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> // CHECK: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32> } @@ -169,7 +169,7 @@ func.func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> func.return %0 : tensor<8x8x8x8xf32> // CHECK-LABEL: fakeQuantArgsFalse - // CHECK: "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform>} + // CHECK: "tfl.quantize"(%arg0) <{qtype = tensor<8x8x8x8x!quant.uniform>}> // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform>) -> tensor<8x8x8x8xf32> } @@ -178,7 +178,7 @@ func.func @fakeQuantArgsTrue(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> func.return %0 : tensor<8x8x8x8xf32> // CHECK-LABEL: fakeQuantArgsTrue - // CHECK: "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform:f32, 0.001181102379804521:86>>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform:f32, 0.001181102379804521:86>> + // CHECK: "tfl.quantize"(%arg0) <{qtype = tensor<8x8x8x8x!quant.uniform:f32, 0.001181102379804521:86>>}> : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform:f32, 0.001181102379804521:86>> // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform:f32, 0.001181102379804521:86>>) -> tensor<8x8x8x8xf32> } @@ -189,7 +189,7 @@ func.func @fakeQuantVarsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> func.return %0 : tensor<8x8x8x8xf32> // CHECK-LABEL: fakeQuantVarsFalse - // CHECK: "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform>} + // CHECK: "tfl.quantize"(%arg0) <{qtype = tensor<8x8x8x8x!quant.uniform>}> // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform>) -> tensor<8x8x8x8xf32> } @@ -206,7 +206,7 @@ func.func @fakeQuantArgsFalse4Bits(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8 func.return %0 : tensor<8x8x8x8xf32> // CHECK-LABEL: fakeQuantArgsFalse - // CHECK: "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform>} + // CHECK: "tfl.quantize"(%arg0) <{qtype = tensor<8x8x8x8x!quant.uniform>}> // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform>) -> tensor<8x8x8x8xf32> } @@ -215,7 +215,7 @@ func.func @fakeQuantArgsTrue4Bits(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x func.return %0 : tensor<8x8x8x8xf32> // CHECK-LABEL: fakeQuantArgsTrue - // CHECK: "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform:f32, 0.021428571747882024:6>>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform:f32, 0.021428571747882024:6>> + // CHECK: "tfl.quantize"(%arg0) <{qtype = tensor<8x8x8x8x!quant.uniform:f32, 0.021428571747882024:6>>}> : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform:f32, 0.021428571747882024:6>> // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform:f32, 0.021428571747882024:6>>) -> tensor<8x8x8x8xf32> } @@ -226,7 +226,7 @@ func.func @fakeQuantVarsFalse4Bits(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8 func.return %0 : tensor<8x8x8x8xf32> // CHECK-LABEL: fakeQuantVarsFalse - // CHECK: 
"tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform>} + // CHECK: "tfl.quantize"(%arg0) <{qtype = tensor<8x8x8x8x!quant.uniform>}> // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform>) -> tensor<8x8x8x8xf32> } @@ -243,7 +243,7 @@ func.func @const() -> tensor<2xi32> { func.return %0: tensor<2xi32> // CHECK-LABEL: @const -// CHECK: "tfl.pseudo_const"() {value = #tf_type : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: "tfl.pseudo_const"() <{value = #tf_type : tensor<2xi32>}> : () -> tensor<2xi32> } func.func @shape(%arg0: tensor) -> tensor<2xi32> { @@ -357,7 +357,7 @@ func.func @maxPool2D(%arg0: tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32> { func.return %6 : tensor<1x1x1x16xf32> // CHECK-LABEL: func @maxPool2D -// CHECK: "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32> +// CHECK: "tfl.max_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32}> : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32> // CHECK: %1 = "tf.MaxPool"(%arg0) // CHECK: %2 = "tf.MaxPool"(%arg0) } @@ -379,7 +379,7 @@ func.func @maxPool2DChannelFirst(%arg0: tensor<1x16x6x6xf32>) -> tensor<1x16x1x1 // CHECK-LABEL: func @maxPool2DChannelFirst // CHECK: %cst = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> // CHECK: %0 = "tfl.transpose"(%arg0, %cst) : (tensor<1x16x6x6xf32>, tensor<4xi32>) -> tensor<1x6x6x16xf32> -// CHECK: %1 = "tfl.max_pool_2d"(%0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> +// CHECK: %1 = "tfl.max_pool_2d"(%0) <{filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32}> : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> // CHECK: %cst_0 = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> // CHECK: %2 = "tfl.transpose"(%1, %cst_0) : (tensor<1x1x1x16xf32>, tensor<4xi32>) -> tensor<1x16x1x1xf32> // CHECK: %3 = "tf.MaxPool"(%arg0) @@ -399,7 +399,7 @@ func.func @any(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { func.return %0 : tensor // CHECK-LABEL:any -// CHECK: "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor +// CHECK: "tfl.reduce_any"(%arg0, %arg1) <{keep_dims = false}> : (tensor<2x2xi1>, tensor) -> tensor } func.func @any_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor { @@ -408,7 +408,7 @@ func.func @any_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tenso // CHECK-LABEL: any_i64axes // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32> - // CHECK: "tfl.reduce_any"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor + // CHECK: "tfl.reduce_any"(%arg0, %[[V0]]) <{keep_dims = false}> : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor } func.func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { @@ -449,7 +449,7 @@ func.func @squeezeDefault(%arg0: tensor<1x2x2xf32>) -> tensor<2x2xf32> { func.return %0 : tensor<2x2xf32> // CHECK-LABEL:squeezeDefault -// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = []} : (tensor<1x2x2xf32>) -> tensor<2x2xf32> +// CHECK: "tfl.squeeze"(%arg0) <{squeeze_dims = []}> : (tensor<1x2x2xf32>) -> tensor<2x2xf32> } func.func 
@squeezeSingleAxis(%arg0: tensor<2x1x2xf32>) -> tensor<2x2xf32> { @@ -457,7 +457,7 @@ func.func @squeezeSingleAxis(%arg0: tensor<2x1x2xf32>) -> tensor<2x2xf32> { func.return %0 : tensor<2x2xf32> // CHECK-LABEL:squeezeSingleAxis -// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [1]} : (tensor<2x1x2xf32>) -> tensor<2x2xf32> +// CHECK: "tfl.squeeze"(%arg0) <{squeeze_dims = [1]}> : (tensor<2x1x2xf32>) -> tensor<2x2xf32> } func.func @squeezeTwoAxes(%arg0: tensor<1x2x1x2xf32>) -> tensor<2x2xf32> { @@ -465,7 +465,7 @@ func.func @squeezeTwoAxes(%arg0: tensor<1x2x1x2xf32>) -> tensor<2x2xf32> { func.return %0 : tensor<2x2xf32> // CHECK-LABEL:squeezeTwoAxes -// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0, 2]} : (tensor<1x2x1x2xf32>) -> tensor<2x2xf32> +// CHECK: "tfl.squeeze"(%arg0) <{squeeze_dims = [0, 2]}> : (tensor<1x2x1x2xf32>) -> tensor<2x2xf32> } func.func @gatherScalarIndices(%arg0 : tensor<3x2xf32>, %arg1 : tensor) -> tensor<2xf32> { @@ -473,7 +473,7 @@ func.func @gatherScalarIndices(%arg0 : tensor<3x2xf32>, %arg1 : tensor) -> func.return %0 : tensor<2xf32> // CHECK-LABEL:gatherScalarIndices -// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<3x2xf32>, tensor) -> tensor<2xf32> +// CHECK: "tfl.gather"(%arg0, %arg1) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3x2xf32>, tensor) -> tensor<2xf32> } func.func @gatherVectorIndices(%arg0 : tensor<2xf32>, %arg1 : tensor<3xi32>) -> tensor<3xf32> { @@ -481,7 +481,7 @@ func.func @gatherVectorIndices(%arg0 : tensor<2xf32>, %arg1 : tensor<3xi32>) -> func.return %0 : tensor<3xf32> // CHECK-LABEL:gatherVectorIndices -// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32> +// CHECK: "tfl.gather"(%arg0, %arg1) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32> } func.func @gatherHigherRankIndices(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<4x5xi32>) -> tensor<4x5x3x6xf32> { @@ -489,7 +489,7 @@ func.func @gatherHigherRankIndices(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<4x5 func.return %0 : tensor<4x5x3x6xf32> // CHECK-LABEL:gatherHigherRankIndices -// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32> +// CHECK: "tfl.gather"(%arg0, %arg1) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32> } func.func @gatherNdVectorIndices(%arg0 : tensor<3x2x2xf32>, %arg1 : tensor<2xi32>) -> tensor<2xf32> { @@ -544,7 +544,7 @@ func.func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5x func.return %1 : tensor<1x3x5x20xf32> // CHECK-LABEL:gatherV2VectorIndices -// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 0 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32> +// CHECK: "tfl.gather"(%arg0, %arg1) <{axis = 1 : i32, batch_dims = 0 : i32}> : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32> } func.func @gatherV2VectorIndices_I64Axis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> { @@ -553,7 +553,7 @@ func.func @gatherV2VectorIndices_I64Axis(%arg0 : tensor<1x2x20xf32>, %arg1 : ten func.return %1 : tensor<1x3x5x20xf32> // CHECK-LABEL:gatherV2VectorIndices_I64Axis -// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 0 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32> +// CHECK: "tfl.gather"(%arg0, %arg1) <{axis = 1 : i32, batch_dims = 0 : i32}> : 
(tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32> } func.func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> { @@ -562,7 +562,7 @@ func.func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tens func.return %1 : tensor<1x2x3x5xf32> // CHECK-LABEL:gatherV2VectorIndices -// CHECK: "tfl.gather"(%arg0, %arg1) {axis = -1 : i32, batch_dims = 0 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x2x3x5xf32> +// CHECK: "tfl.gather"(%arg0, %arg1) <{axis = -1 : i32, batch_dims = 0 : i32}> : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x2x3x5xf32> } func.func @gatherWithBatchDims(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<2x5xi32>) -> tensor<2x5x3x6xf32> { @@ -571,7 +571,7 @@ func.func @gatherWithBatchDims(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<2x5xi32 func.return %1 : tensor<2x5x3x6xf32> // CHECK-LABEL:gatherWithBatchDims -// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 1 : i32} : (tensor<2x3x6xf32>, tensor<2x5xi32>) -> tensor<2x5x3x6xf32> +// CHECK: "tfl.gather"(%arg0, %arg1) <{axis = 1 : i32, batch_dims = 1 : i32}> : (tensor<2x3x6xf32>, tensor<2x5xi32>) -> tensor<2x5x3x6xf32> } @@ -1037,7 +1037,7 @@ func.func @pack2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x func.return %0 : tensor<2x2xi32> // CHECK-LABEL: pack2Tensors -// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> +// CHECK: "tfl.pack"(%arg0, %arg1) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> } func.func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> { @@ -1045,7 +1045,7 @@ func.func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tens func.return %0 : tensor<2x3xi32> // CHECK-LABEL: pack3Tensors -// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> +// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) <{axis = 1 : i32, values_count = 3 : i32}> : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> } func.func @packStringWithFlex(%arg0: tensor<2x!tf_type.string>, %arg1: tensor<2x!tf_type.string>) -> tensor<2x2x!tf_type.string> { @@ -1061,7 +1061,7 @@ func.func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tenso func.return %0 : tensor<2x3xi32> // CHECK-LABEL: packNegAxis -// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = -1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> +// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) <{axis = -1 : i32, values_count = 3 : i32}> : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> } func.func @unpack2Tensors(%arg0: tensor<2x2xi32>) -> tensor<2xi32> { @@ -1069,7 +1069,7 @@ func.func @unpack2Tensors(%arg0: tensor<2x2xi32>) -> tensor<2xi32> { func.return %0#0 : tensor<2xi32> // CHECK-LABEL: unpack2Tensors -// CHECK: "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x2xi32>) -> (tensor<2xi32>, tensor<2xi32>) +// CHECK: "tfl.unpack"(%arg0) <{axis = 0 : i32, num = 2 : i32}> : (tensor<2x2xi32>) -> (tensor<2xi32>, tensor<2xi32>) } func.func @unpack3Tensors(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { @@ -1077,7 +1077,7 @@ func.func @unpack3Tensors(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { func.return %0#0 : tensor<2xi32> // CHECK-LABEL: unpack3Tensors -// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, 
-// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
+// CHECK: "tfl.unpack"(%arg0) <{axis = 1 : i32, num = 3 : i32}> : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
}
func.func @unpackNegAxis(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
@@ -1085,7 +1085,7 @@ func.func @unpackNegAxis(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
func.return %0#0 : tensor<2xi32>
// CHECK-LABEL: unpackNegAxis
-// CHECK: "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
+// CHECK: "tfl.unpack"(%arg0) <{axis = -1 : i32, num = 3 : i32}> : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
}
func.func @mean(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32> {
@@ -1093,7 +1093,7 @@ func.func @mean(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32>
func.return %0 : tensor<1x2xf32>
// CHECK-LABEL: mean
-// CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
+// CHECK: "tfl.mean"(%arg0, %arg1) <{keep_dims = false}> : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
}
func.func @mean_true(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32> {
@@ -1101,7 +1101,7 @@ func.func @mean_true(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2
func.return %0 : tensor<1x2xf32>
// CHECK-LABEL: mean_true
-// CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = true} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
+// CHECK: "tfl.mean"(%arg0, %arg1) <{keep_dims = true}> : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
}
func.func @sum(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor {
@@ -1109,7 +1109,7 @@ func.func @sum(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor
// CHECK-LABEL: sum
- // CHECK: "tfl.sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.sum"(%arg0, %arg1) <{keep_dims = false}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor {
@@ -1117,7 +1117,7 @@ func.func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<
func.return %0 : tensor
// CHECK-LABEL: sum_true
- // CHECK: "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.sum"(%arg0, %arg1) <{keep_dims = true}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @sum_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor {
@@ -1126,7 +1126,7 @@ func.func @sum_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tens
// CHECK-LABEL: sum_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
- // CHECK: "tfl.sum"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.sum"(%arg0, %[[V0]]) <{keep_dims = false}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor {
@@ -1134,7 +1134,7 @@ func.func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
func.return %0 : tensor
// CHECK-LABEL: reduce_min
- // CHECK: "tfl.reduce_min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.reduce_min"(%arg0, %arg1) <{keep_dims = false}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor {
@@ -1142,7 +1142,7 @@ func.func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) ->
func.return %0 : tensor
// CHECK-LABEL: reduce_min_true
- // CHECK: "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.reduce_min"(%arg0, %arg1) <{keep_dims = true}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @reduce_min_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor {
@@ -1151,7 +1151,7 @@ func.func @reduce_min_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>)
// CHECK-LABEL: reduce_min_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
- // CHECK: "tfl.reduce_min"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.reduce_min"(%arg0, %[[V0]]) <{keep_dims = false}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor {
@@ -1159,7 +1159,7 @@ func.func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
func.return %0 : tensor
// CHECK-LABEL: reduce_max
- // CHECK: "tfl.reduce_max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.reduce_max"(%arg0, %arg1) <{keep_dims = false}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor {
@@ -1167,7 +1167,7 @@ func.func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) ->
func.return %0 : tensor
// CHECK-LABEL: reduce_max_true
- // CHECK: "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.reduce_max"(%arg0, %arg1) <{keep_dims = true}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @reduce_max_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor {
@@ -1176,7 +1176,7 @@ func.func @reduce_max_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>)
// CHECK-LABEL: reduce_max_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
- // CHECK: "tfl.reduce_max"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.reduce_max"(%arg0, %[[V0]]) <{keep_dims = false}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @reduce_prod(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor {
@@ -1184,7 +1184,7 @@ func.func @reduce_prod(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tens
func.return %0 : tensor
// CHECK-LABEL: reduce_prod
- // CHECK: "tfl.reduce_prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.reduce_prod"(%arg0, %arg1) <{keep_dims = false}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @reduce_prod_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor {
@@ -1192,7 +1192,7 @@ func.func @reduce_prod_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) ->
func.return %0 : tensor
// CHECK-LABEL: reduce_prod_true
- // CHECK: "tfl.reduce_prod"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.reduce_prod"(%arg0, %arg1) <{keep_dims = true}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @reduce_prod_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor {
@@ -1201,7 +1201,7 @@ func.func @reduce_prod_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>)
// CHECK-LABEL: reduce_prod_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
- // CHECK: "tfl.reduce_prod"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
+ // CHECK: "tfl.reduce_prod"(%arg0, %[[V0]]) <{keep_dims = false}> : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor
}
func.func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor {
@@ -1248,7 +1248,7 @@ func.func @split(%arg0: tensor, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3
func.return %0#0 : tensor<1x4x3xf32>
// CHECK-LABEL: split
- // CHECK: "tfl.split"(%arg0, %arg1) {num_splits = 3 : i32} : (tensor, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
+ // CHECK: "tfl.split"(%arg0, %arg1) <{num_splits = 3 : i32}> : (tensor, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
}
func.func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor) -> tensor<1x4x2x3xf32> {
@@ -1256,7 +1256,7 @@ func.func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tenso
func.return %0#0 : tensor<1x4x2x3xf32>
// CHECK-LABEL: splitv
- // CHECK: "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
+ // CHECK: "tfl.split_v"(%arg0, %arg1, %arg2) <{num_splits = 2 : i32}> : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
}
func.func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
@@ -1266,8 +1266,8 @@ func.func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<
// CHECK-LABEL: matmul
// CHECK: %[[CST:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
-// CHECK: %[[CST_0:.*]] = "tfl.no_value"() {value} : () -> none
-// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+// CHECK: %[[CST_0:.*]] = "tfl.no_value"() <{value}> : () -> none
+// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func.func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
@@ -1278,8 +1278,8 @@ func.func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32
// CHECK: %[[CST_0:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
-// CHECK: %[[CST_2:.*]] = "tfl.no_value"() {value} : () -> none
-// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+// CHECK: %[[CST_2:.*]] = "tfl.no_value"() <{value}> : () -> none
+// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func.func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
@@ -1287,8 +1287,8 @@ func.func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32
(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
func.return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul_transposed_b
-// CHECK: %[[CST:.*]] = "tfl.no_value"() {value} : () -> none
-// CHECK: "tfl.fully_connected"(%arg0, %arg1, %[[CST]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+// CHECK: %[[CST:.*]] = "tfl.no_value"() <{value}> : () -> none
+// CHECK: "tfl.fully_connected"(%arg0, %arg1, %[[CST]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func.func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
@@ -1298,8 +1298,8 @@ func.func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf3
// CHECK-LABEL: matmul_transposed_ab
// CHECK: %[[CST_0:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
-// CHECK: %[[CST_1:.*]] = "tfl.no_value"() {value} : () -> none
-// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+// CHECK: %[[CST_1:.*]] = "tfl.no_value"() <{value}> : () -> none
+// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func.func @concat_v2_with_3_tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
@@ -1308,7 +1308,7 @@ func.func @concat_v2_with_3_tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi3
func.return %1 : tensor<2x3xi32>
// CHECK-LABEL: concat_v2_with_3_tensors
-// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
+// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func.func @concat_v2_i64_axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
@@ -1317,7 +1317,7 @@ func.func @concat_v2_i64_axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %a
func.return %1 : tensor<2x3xi32>
// CHECK-LABEL: concat_v2_i64_axis
-// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
+// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func.func @concat_v2_with_bool_type(%arg0: tensor, %arg1: tensor) -> tensor {
@@ -1326,28 +1326,28 @@ func.func @concat_v2_with_bool_type(%arg0: tensor, %arg1: tensor
func.return %1 : tensor
// CHECK-LABEL: concat_v2_with_bool_type
-// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor, tensor) -> tensor
+// CHECK: "tfl.concatenation"(%arg0, %arg1) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor, tensor) -> tensor
}
func.func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
func.return %0 : tensor
// CHECK-LABEL: resize_with_bilinear
- // CHECK: "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = true, half_pixel_centers = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
+ // CHECK: "tfl.resize_bilinear"(%arg0, %arg1) <{align_corners = true, half_pixel_centers = false}> : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
}
func.func @resize_with_bilinear_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
func.return %0 : tensor
// CHECK-LABEL: resize_with_bilinear_with_half_pixel_centers
- // CHECK: "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
+ // CHECK: "tfl.resize_bilinear"(%arg0, %arg1) <{align_corners = false, half_pixel_centers = true}> : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
}
func.func @strided_slice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
func.return %0 : tensor<1x2x2x5xf32>
// CHECK-LABEL: strided_slice
- // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
+ // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
}
func.func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<10x10xf32> {
@@ -1360,14 +1360,14 @@ func.func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %
// CHECK-DAG: [[BEGIN:%cst.*]] = arith.constant dense<-1> : tensor<1xi32>
// CHECK-DAG: [[END:%cst.*]] = arith.constant dense<0> : tensor<1xi32>
// CHECK-DAG: [[STRIDES:%cst.*]] = arith.constant dense<1> : tensor<1xi32>
- // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
+ // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32}> : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
}
func.func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf_type.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<12x2x2x5x!tf_type.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string>
func.return %0 : tensor<1x2x2x5x!tf_type.string>
// CHECK-LABEL: strided_slice_with_string
- // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf_type.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string>
+ // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<12x2x2x5x!tf_type.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string>
}
func.func @strided_slice_with_unranked_input_and_i64_parameters(%arg0: tensor<*xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<*xf32> {
@@ -1377,7 +1377,7 @@ func.func @strided_slice_with_unranked_input_and_i64_parameters(%arg0: tensor<*x
// CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
- // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32>
+ // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32>
}
func.func @strided_slice_with_i64_parameters(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<1x2x2x5xf32> {
@@ -1387,7 +1387,7 @@ func.func @strided_slice_with_i64_parameters(%arg0: tensor<12x2x2x5xf32>, %arg1:
// CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
- // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
 + // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
}
func.func @strided_slice_with_i64_constant_attributes(%arg0: tensor<10x10x10xf32>) -> tensor<10x10xf32> {
@@ -1400,21 +1400,21 @@ func.func @strided_slice_with_i64_constant_attributes(%arg0: tensor<10x10x10xf32
// CHECK-DAG: [[BEGIN:%cst.*]] = arith.constant dense<-1> : tensor<1xi32>
// CHECK-DAG: [[END:%cst.*]] = arith.constant dense<0> : tensor<1xi32>
// CHECK-DAG: [[STRIDES:%cst.*]] = arith.constant dense<1> : tensor<1xi32>
- // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
+ // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32}> : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
}
func.func @strided_slice_non_zero_ellipsis_mask(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
func.return %0 : tensor<1x2x2x5xf32>
// CHECK-LABEL: strided_slice_non_zero_ellipsis_mask
- // CHECK: %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 1 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
+ // CHECK: %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 1 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
}
func.func @strided_slice_non_zero_new_axis_mask(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 2 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
func.return %0 : tensor<1x2x2x5xf32>
// CHECK-LABEL: strided_slice_non_zero_new_axis_mask
- // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 2 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
+ // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 2 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
}
func.func @strided_slice_big_dims(%arg0: tensor<5x6x7xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>) -> tensor<1x1x5x6x7xf32> {
@@ -1437,7 +1437,7 @@ func.func @mirror_pad(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor {
func.return %0#0 : tensor
// CHECK-LABEL: mirror_pad
- // CHECK: "tfl.mirror_pad"(%arg0, %arg1) {mode = #tfl} : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor
+ // CHECK: "tfl.mirror_pad"(%arg0, %arg1) <{mode = #tfl}> : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor
// CHECK: return
}
@@ -1447,7 +1447,7 @@ func.func @mirror_pad_reflect(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor
// CHECK-LABEL: mirror_pad_reflect
- // CHECK: "tfl.mirror_pad"(%arg0, %arg1) {mode = #tfl} : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor
+ // CHECK: "tfl.mirror_pad"(%arg0, %arg1) <{mode = #tfl}> : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor
// CHECK: return
}
@@ -1512,7 +1512,7 @@ func.func @ReverseSequence(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tens
func.return %0: tensor<2x3xf32>
// CHECK-LABEL: ReverseSequence
-// CHECK: "tfl.reverse_sequence"(%arg0, %arg1) {batch_dim = 0 : i32, seq_dim = 0 : i32} : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+// CHECK: "tfl.reverse_sequence"(%arg0, %arg1) <{batch_dim = 0 : i32, seq_dim = 0 : i32}> : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
}
func.func @LRN(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
@@ -1520,7 +1520,7 @@ func.func @LRN(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
func.return %0: tensor<2x3x4x5xf32>
// CHECK-LABEL: LRN
- // CHECK: "tfl.local_response_normalization"(%arg0) {alpha = 1.000000e+00 : f32, beta = 5.000000e-01 : f32, bias = 1.000000e+00 : f32, radius = 5 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
+ // CHECK: "tfl.local_response_normalization"(%arg0) <{alpha = 1.000000e+00 : f32, beta = 5.000000e-01 : f32, bias = 1.000000e+00 : f32, radius = 5 : i32}> : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
// CHECK: return %0 : tensor<2x3x4x5xf32>
}
@@ -1529,7 +1529,7 @@ func.func @OneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor,
func.return %0: tensor<*xf32>
// CHECK-LABEL: OneHot
-// CHECK: "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32>
+// CHECK: "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) <{axis = -1 : i32}> : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32>
}
func.func @argmax(%arg0: tensor<3xi32>, %arg1: tensor) -> tensor {
@@ -1554,7 +1554,7 @@ func.func @space_to_depth(%arg0: tensor<1x2x2x1xf32>) -> tensor {
// CHECK-LABEL: space_to_depth
// CHECK: %[[ARG:.*]]: tensor<1x2x2x1xf32>
- // CHECK: "tfl.space_to_depth"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x2x2x1xf32>) -> tensor
+ // CHECK: "tfl.space_to_depth"(%[[ARG]]) <{block_size = 2 : i32}> : (tensor<1x2x2x1xf32>) -> tensor
}
func.func @round(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
@@ -1571,14 +1571,14 @@ func.func @resize_nearest_neighbor(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor
%0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
func.return %0 : tensor
// CHECK-LABEL: resize_nearest_neighbor
- // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
+ // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) <{align_corners = true, half_pixel_centers = false}> : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
}
func.func @resize_nearest_neighbor_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor {
%0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
func.return %0 : tensor
// CHECK-LABEL: resize_nearest_neighbor_with_half_pixel_centers
- // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
+ // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) <{align_corners = false, half_pixel_centers = true}> : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor
}
func.func @sparse_to_dense_with_scalar_sparse_indices(%arg0: tensor, %arg1: tensor<3xi32>, %arg2: tensor, %arg3: tensor) -> tensor {
@@ -1643,7 +1643,7 @@ func.func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> {
// CHECK-LABEL: depth_to_space
// CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32>
- // CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32>
+ // CHECK: "tfl.depth_to_space"(%[[ARG]]) <{block_size = 2 : i32}> : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32>
}
func.func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor<2xi32> {
@@ -1687,9 +1687,9 @@ func.func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf3
// CHECK-LABEL: conv2d_backprop_input
// CHECK: %[[CST:.*]] = arith.constant dense<[2, 0, 1, 3]> : tensor<4xi32>
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
- // CHECK: %[[CST_0:.*]] = "tfl.no_value"() {value} : () -> none
- // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
- // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
+ // CHECK: %[[CST_0:.*]] = "tfl.no_value"() <{value}> : () -> none
+ // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) <{fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
+ // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) <{fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
// CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32>
// CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32>
}
@@ -1734,7 +1734,7 @@ func.func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> {
// CHECK-LABEL: reciprocal_f32
// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor
-// CHECK: tfl.div(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor, tensor<8xf32>) -> tensor<8xf32>
+// CHECK: tfl.div(%cst, %arg0) <{fused_activation_function = "NONE"}> : (tensor, tensor<8xf32>) -> tensor<8xf32>
// CHECK: return
}
@@ -1744,7 +1744,7 @@ func.func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> {
// CHECK-LABEL: reciprocal_i32
// CHECK: %cst = arith.constant dense<1> : tensor
-// CHECK: tfl.div(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor, tensor<8xi32>) -> tensor<8xi32>
+// CHECK: tfl.div(%cst, %arg0) <{fused_activation_function = "NONE"}> : (tensor, tensor<8xi32>) -> tensor<8xi32>
// CHECK: return
}
@@ -1763,8 +1763,8 @@ func.func @LstmWithoutProjection(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x16xf
// CHECK-DAG: [[VAL_2:%.*]] = arith.constant dense<0.000000e+00> : tensor<16x16xf32>
// CHECK-DAG: [[VAL_3:%.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32>
// CHECK-DAG: [[VAL_4:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x16xf32>
-// CHECK-DAG: [[VAL_5:%.*]] = "tfl.no_value"() {value} : () -> none
-// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
+// CHECK-DAG: [[VAL_5:%.*]] = "tfl.no_value"() <{value}> : () -> none
+// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) <{cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
// CHECK: return [[VAL_6]] : tensor<28x1x16xf32>
// CHECK: }
@@ -1788,8 +1788,8 @@ func.func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>)
// CHECK-DAG: [[VAL_11:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x16xf32>
// CHECK-DAG: [[VAL_12:%.*]] = arith.constant dense<0.000000e+00> : tensor<8x16xf32>
// CHECK-DAG: [[VAL_13:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x8xf32>
-// CHECK-DAG: [[VAL_14:%.*]] = "tfl.no_value"() {value} : () -> none
-// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
+// CHECK-DAG: [[VAL_14:%.*]] = "tfl.no_value"() <{value}> : () -> none
+// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) <{cell_clip = 0.000000e+00 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
// CHECK: }
@@ -1805,7 +1805,7 @@ func.func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>)
// CHECK-DAG: [[VAL_1:%.*]] = arith.constant dense<0.000000e+00> : tensor<28x28xf32>
// CHECK-DAG: [[VAL_2:%.*]] = arith.constant dense<0.000000e+00> : tensor<28xf32>
// CHECK-DAG: [[VAL_3:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x28xf32>
-// CHECK: [[VAL_4:%.*]] = "tfl.unidirectional_sequence_rnn"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]]) {fused_activation_function = "TANH", time_major = true} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> tensor<28x1x28xf32>
+// CHECK: [[VAL_4:%.*]] = "tfl.unidirectional_sequence_rnn"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]]) <{fused_activation_function = "TANH", time_major = true}> : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> tensor<28x1x28xf32>
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
// CHECK: }
@@ -1832,7 +1832,7 @@ func.func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> t
(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
func.return %0 : tensor<10x17xf32>
// CHECK-LABEL: matmul_batch
-// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
+// CHECK: "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
}
func.func @matmul_batchv2(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<2x10x17xf32> {
@@ -1840,7 +1840,7 @@ func.func @matmul_batchv2(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>)
(tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
func.return %0 : tensor<2x10x17xf32>
// CHECK-LABEL: matmul_batchv2
-// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
+// CHECK: "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
}
func.func @matmul_batchv3(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<2x10x17xf32> {
@@ -1848,7 +1848,7 @@ func.func @matmul_batchv3(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>)
(tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
func.return %0 : tensor<2x10x17xf32>
// CHECK-LABEL: matmul_batchv3
-// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
+// CHECK: "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
}
func.func @matmul_batchv3_int8(%arg0: tensor<2x10x15xi8>, %arg1: tensor<15x17xi8>) -> tensor<2x10x17xi32> {
@@ -1856,7 +1856,7 @@ func.func @matmul_batchv3_int8(%arg0: tensor<2x10x15xi8>, %arg1: tensor<15x17xi8
(tensor<2x10x15xi8>, tensor<15x17xi8>) -> tensor<2x10x17xi32>
func.return %0 : tensor<2x10x17xi32>
// CHECK-LABEL: matmul_batchv3_int8
-// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x10x15xi8>, tensor<15x17xi8>) -> tensor<2x10x17xi32>
+// CHECK: "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<2x10x15xi8>, tensor<15x17xi8>) -> tensor<2x10x17xi32>
}
func.func @matmul_batchv2_unknown_dim(%arg0: tensor, %arg1: tensor<15x17xf32>) -> tensor {
@@ -1864,7 +1864,7 @@ func.func @matmul_batchv2_unknown_dim(%arg0: tensor, %arg1: tensor<
(tensor, tensor<15x17xf32>) -> tensor
func.return %0 : tensor
// CHECK-LABEL: matmul_batchv2_unknown_dim
-// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor, tensor<15x17xf32>) -> tensor
+// CHECK: "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor, tensor<15x17xf32>) -> tensor
}
func.func @matmul_batchv3_unknown_dim(%arg0: tensor, %arg1: tensor<15x17xf32>) -> tensor {
@@ -1872,7 +1872,7 @@ func.func @matmul_batchv3_unknown_dim(%arg0: tensor, %arg1: tensor<
(tensor, tensor<15x17xf32>) -> tensor
func.return %0 : tensor
// CHECK-LABEL: matmul_batchv3_unknown_dim
-// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor, tensor<15x17xf32>) -> tensor
+// CHECK: "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor, tensor<15x17xf32>) -> tensor
}
func.func @matmul_batchv3_unknown_dim_bf16(%arg0: tensor, %arg1: tensor<5x6xf32>) -> tensor {
@@ -1883,7 +1883,7 @@ func.func @matmul_batchv3_unknown_dim_bf16(%arg0: tensor, %arg1: ten
func.return %2 : tensor
// CHECK-LABEL: matmul_batchv3_unknown_dim_bf16
// CHECK: [[CST:%.*]] = "tfl.cast"(%arg0) : (tensor) -> tensor
-// CHECK: [[BMM:%.*]] = "tfl.batch_matmul"([[CST]], %arg1) {adj_x = false, adj_y = false} : (tensor, tensor<5x6xf32>) -> tensor
+// CHECK: [[BMM:%.*]] = "tfl.batch_matmul"([[CST]], %arg1) <{adj_x = false, adj_y = false}> : (tensor, tensor<5x6xf32>) -> tensor
// CHECK: "tfl.cast"([[BMM]]) : (tensor) -> tensor
}
@@ -1918,12 +1918,12 @@ func.func @test5DAddWithImplicitBroadcast(%arg0: tensor<1x1x1x3x1xi32>, %arg1 :
%0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
func.return %0 : tensor<1x1x1x3x4xi32>
// CHECK-LABEL: test5DAddWithImplicitBroadcast
-// CHECK: %0 = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
+// CHECK: %0 = tfl.add(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
}
func.func @test6DAddWithImplicitBroadcast(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK-LABEL: test6DAddWithImplicitBroadcast
-// CHECK: %0 = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+// CHECK: %0 = 
tfl.add(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> func.return %0 : tensor<1x2x3x4x5x6xi32> } @@ -1941,12 +1941,12 @@ func.func @test5DSubWithImplicitBroadcast(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : %0 = "tf.Sub"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> func.return %0 : tensor<1x1x1x3x4xi32> // CHECK-LABEL: test5DSubWithImplicitBroadcast -// CHECK: %0 = tfl.sub(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> +// CHECK: %0 = tfl.sub(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> } func.func @test6DSubWithImplicitBroadcast(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> { // CHECK-LABEL: test6DSubWithImplicitBroadcast -// CHECK: %0 = tfl.sub(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> +// CHECK: %0 = tfl.sub(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> func.return %0 : tensor<1x2x3x4x5x6xi32> } @@ -1964,12 +1964,12 @@ func.func @test5DMulWithImplicitBroadcast(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : %0 = "tf.Mul"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> func.return %0 : tensor<1x1x1x3x4xi32> // CHECK-LABEL: test5DMulWithImplicitBroadcast -// CHECK: %0 = tfl.mul(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> +// CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> } func.func @test6DMulWithImplicitBroadcast(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> { // CHECK-LABEL: test6DMulWithImplicitBroadcast -// CHECK: %0 = tfl.mul(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> +// CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> func.return %0 : tensor<1x2x3x4x5x6xi32> } @@ -2148,7 +2148,7 @@ func.func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor) -> tensor<3x3xf32> %0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> func.return %0 : tensor<3x3xf32> // CHECK-LABEL: cumsum - // CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> + // CHECK: "tfl.cumsum"(%arg0, %arg1) <{exclusive = false, reverse = false}> : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> } func.func @cumsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor) -> tensor<3x3xf32> { @@ -2314,8 +2314,8 @@ func.func @conv3d_valid(%arg0: tensor,%arg1: tensor // CHECK-LABEL: conv3d_valid - // CHECK: %[[CST:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: 
[[BCT:%.*]] = "tfl.conv_3d"(%arg0, %arg1, %[[CST]]) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor, tensor, none) -> tensor + // CHECK: %[[CST:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: [[BCT:%.*]] = "tfl.conv_3d"(%arg0, %arg1, %[[CST]]) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor, tensor, none) -> tensor // CHECK: return [[BCT]] : tensor } @@ -2359,7 +2359,7 @@ func.func @all(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { func.return %0 : tensor // CHECK-LABEL:all -// CHECK: "tfl.reduce_all"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor +// CHECK: "tfl.reduce_all"(%arg0, %arg1) <{keep_dims = false}> : (tensor<2x2xi1>, tensor) -> tensor } func.func @all_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor { @@ -2368,7 +2368,7 @@ func.func @all_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tenso // CHECK-LABEL: all_i64axes // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32> - // CHECK: "tfl.reduce_all"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor + // CHECK: "tfl.reduce_all"(%arg0, %[[V0]]) <{keep_dims = false}> : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor } func.func @quantize_dequantize_v4(%arg0 : tensor) -> tensor { @@ -2378,7 +2378,7 @@ func.func @quantize_dequantize_v4(%arg0 : tensor) -> tensor { func.return %0 : tensor // CHECK-LABEL: quantize_dequantize_v4 -// CHECK: %[[QUANT:.*]] = "tfl.quantize"(%arg0) {qtype = tensor>} : (tensor) -> tensor> +// CHECK: %[[QUANT:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor>}> : (tensor) -> tensor> // CHECK: %[[DEQUANT:.*]] = "tfl.dequantize"(%[[QUANT]]) : (tensor>) -> tensor // CHECK: return %[[DEQUANT]] } @@ -2387,9 +2387,9 @@ func.func @conv3d_transpose(%arg0: tensor<2x5x6x8x2xf32>, %arg1: tensor<1x2x2x3x %0 = "tf.Conv3DBackpropInputV2"(%arg2, %arg1, %arg0) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<5xi64>, tensor<1x2x2x3x2xf32>, tensor<2x5x6x8x2xf32>) -> tensor func.return %0 : tensor // CHECK-LABEL: conv3d_transpose - // CHECK: %[[CST:.*]] = "tfl.no_value"() {value} : () -> none + // CHECK: %[[CST:.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK: %[[OUT_SHAPE:.*]] = "tfl.cast"(%arg2) : (tensor<5xi64>) -> tensor<5xi32> - // CHECK: %[[RESULT:.*]] = "tfl.conv_3d_transpose"(%[[OUT_SHAPE]], %arg1, %arg0, %[[CST]]) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<1x2x2x3x2xf32>, tensor<2x5x6x8x2xf32>, none) -> tensor + // CHECK: %[[RESULT:.*]] = "tfl.conv_3d_transpose"(%[[OUT_SHAPE]], %arg1, %arg0, %[[CST]]) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<5xi32>, tensor<1x2x2x3x2xf32>, tensor<2x5x6x8x2xf32>, none) -> tensor // CHECK: return %[[RESULT]] : tensor } @@ -2464,7 +2464,7 @@ func.func @mul_with_unranked_lhs(%arg0: tensor<*xf32>, %arg1: tensor // 
CHECK-LABEL:mul_with_unranked_lhs - // CHECK: %0 = tfl.mul(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<*xf32>, tensor) -> tensor + // CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<*xf32>, tensor) -> tensor // CHECK: return %0 : tensor } @@ -2530,7 +2530,7 @@ func.func @Bucketize(%arg0: tensor<3x2xf32>) -> tensor<3x2xi32> { func.return %0: tensor<3x2xi32> // CHECK-LABEL: Bucketize -// CHECK: "tfl.bucketize"(%arg0) {boundaries = [1.000000e+00 : f32, 1.000000e+01 : f32, 1.000000e+02 : f32]} : (tensor<3x2xf32>) -> tensor<3x2xi32> +// CHECK: "tfl.bucketize"(%arg0) <{boundaries = [1.000000e+00 : f32, 1.000000e+01 : f32, 1.000000e+02 : f32]}> : (tensor<3x2xf32>) -> tensor<3x2xi32> } func.func @random_uniform_f32(%arg0: tensor<3xi32>) -> tensor { @@ -2538,7 +2538,7 @@ func.func @random_uniform_f32(%arg0: tensor<3xi32>) -> tensor { func.return %0 : tensor // CHECK-LABEL:random_uniform_f32 -// CHECK: "tfl.random_uniform"(%arg0) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<3xi32>) -> tensor +// CHECK: "tfl.random_uniform"(%arg0) <{seed = 0 : i64, seed2 = 0 : i64}> : (tensor<3xi32>) -> tensor } func.func @random_standard_normal_f32(%arg0: tensor<3xi32>) -> tensor { @@ -2546,7 +2546,7 @@ func.func @random_standard_normal_f32(%arg0: tensor<3xi32>) -> tensor func.return %0 : tensor // CHECK-LABEL:random_standard_normal_f32 -// CHECK: "tfl.random_standard_normal"(%arg0) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<3xi32>) -> tensor +// CHECK: "tfl.random_standard_normal"(%arg0) <{seed = 0 : i64, seed2 = 0 : i64}> : (tensor<3xi32>) -> tensor } func.func @multinomial_i64(%arg0: tensor<2xf32>, %arg1: tensor<1xi32>) -> tensor<10xi64> { @@ -2554,7 +2554,7 @@ func.func @multinomial_i64(%arg0: tensor<2xf32>, %arg1: tensor<1xi32>) -> tensor func.return %0 : tensor<10xi64> // CHECK-LABEL:multinomial_i64 -// CHECK: "tfl.multinomial"(%arg0, %arg1) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<2xf32>, tensor<1xi32>) -> tensor<10xi64> +// CHECK: "tfl.multinomial"(%arg0, %arg1) <{seed = 0 : i64, seed2 = 0 : i64}> : (tensor<2xf32>, tensor<1xi32>) -> tensor<10xi64> } func.func @multinomial_i32(%arg0: tensor<2xf32>, %arg1: tensor<1xi32>) -> tensor<10xi32> { @@ -2562,7 +2562,7 @@ func.func @multinomial_i32(%arg0: tensor<2xf32>, %arg1: tensor<1xi32>) -> tensor func.return %0 : tensor<10xi32> // CHECK-LABEL:multinomial_i32 -// CHECK: "tfl.multinomial"(%arg0, %arg1) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<2xf32>, tensor<1xi32>) -> tensor<10xi32> +// CHECK: "tfl.multinomial"(%arg0, %arg1) <{seed = 0 : i64, seed2 = 0 : i64}> : (tensor<2xf32>, tensor<1xi32>) -> tensor<10xi32> } func.func @dynamic_update_slice(%arg0: tensor<4x5xi32>, %arg1: tensor<1x5xi32>, %arg2: tensor<2xi32>) -> tensor<4x5xi32> { @@ -2683,7 +2683,7 @@ func.func @sigmoidGrad(%arg0: tensor, %arg1: tensor) -> tens func.return %0 : tensor // CHECK-LABEL: sigmoidGrad // CHECK-NEXT: [[ONE:%.+]] = arith.constant dense<1.000000e+00> : tensor -// CHECK-NEXT: [[SUB:%.+]] = tfl.sub([[ONE]], %arg0) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor +// CHECK-NEXT: [[SUB:%.+]] = tfl.sub([[ONE]], %arg0) <{fused_activation_function = "NONE"}> : (tensor, tensor) -> tensor // CHECK-NEXT: [[MUL0:%.+]] = tfl.mul %arg0, [[SUB]] {fused_activation_function = "NONE"} : tensor // CHECK-NEXT: [[MUL1:%.+]] = tfl.mul %arg1, [[MUL0]] {fused_activation_function = "NONE"} : tensor // CHECK: return [[MUL1]] @@ -2697,8 +2697,8 @@ func.func @batchmatmul2fullyconnected(%arg0: tensor<4x128x2xf32>) -> (tensor<4x1 // 
CHECK-LABEL: batchmatmul2fullyconnected // CHECK-DAG: %cst_0 = arith.constant dense<[1, 0]> : tensor<2xi32> // CHECK: %0 = "tfl.transpose"(%cst, %cst_0) : (tensor<2x1xf32>, tensor<2xi32>) -> tensor<1x2xf32> - // CHECK-DAG: %1 = "tfl.no_value"() {value} : () -> none - // CHECK: %2 = "tfl.fully_connected"(%arg0, %0, %1) {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> + // CHECK-DAG: %1 = "tfl.no_value"() <{value}> : () -> none + // CHECK: %2 = "tfl.fully_connected"(%arg0, %0, %1) <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> // CHECK: return %2 : tensor<4x128x1xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir b/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir index 76f453d1d3a8aa..fe1c86c3d3feb7 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: func @tfl_wrapped_jax_random_normal( // CHECK-SAME: %[[RNG:.*]]: tensor<2xui32>) -> tuple> { // CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<[3, 4]> : tensor<2xi32> -// CHECK: %[[VAL_1:.*]] = "tfl.custom"(%[[VAL_0]]) {custom_code = "RandomStandardNormal", custom_option = #tfl} : (tensor<2xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_1:.*]] = "tfl.custom"(%[[VAL_0]]) <{custom_code = "RandomStandardNormal", custom_option = #tfl}> : (tensor<2xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_1]] : tuple> // CHECK: return %[[VAL_2]] : tuple> // CHECK: } @@ -20,7 +20,7 @@ func.func @tfl_wrapped_jax_random_normal(%arg0: tensor<2xui32>) -> tuple) -> tuple> { // CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<[1, 2]> : tensor<2xi32> -// CHECK: %[[VAL_1:.*]] = "tfl.custom"(%[[VAL_0]]) {custom_code = "RandomUniform", custom_option = #tfl} : (tensor<2xi32>) -> tensor<1x2xf32> +// CHECK: %[[VAL_1:.*]] = "tfl.custom"(%[[VAL_0]]) <{custom_code = "RandomUniform", custom_option = #tfl}> : (tensor<2xi32>) -> tensor<1x2xf32> // CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_1]] : tuple> // CHECK: return %[[VAL_2]] : tuple> // CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir index 89b3a2b7caa079..10f051bf66c966 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir @@ -59,7 +59,7 @@ func.func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } -// IMPORT: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32} +// IMPORT: "tfl.fake_quant"(%arg0) <{max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32}> %0 = "tfl.fake_quant"(%arg0) {num_bits = 6 : i32, narrow_range = false, min = 0.3:f32, max = 1.4:f32} : (tensor<4 x f32>) -> tensor<4 x f32> func.return %0 : tensor<4xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/modify_io_nodes.mlir b/tensorflow/compiler/mlir/lite/tests/modify_io_nodes.mlir index 686fcc5703552e..38d2a05904e418 100644 --- a/tensorflow/compiler/mlir/lite/tests/modify_io_nodes.mlir +++ b/tensorflow/compiler/mlir/lite/tests/modify_io_nodes.mlir @@ -9,39 +9,39 @@ func.func @modified(%arg0: 
tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> attr %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> %4 = "tfl.reshape"(%3, %cst) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> - %5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> + %5 = "tfl.softmax"(%4) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> %6 = "tfl.dequantize"(%5) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> func.return %6 : tensor<1x401408xf32> // CHECK-LABEL: func @modified(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> // CHECK-NEXT: %[[shape:.*]] = arith.constant dense<[1, 401408]> : tensor<2xi32> -// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> -// CHECK-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> -// CHECK-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> -// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> +// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x224x224x3x!quant.uniform>}> : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> +// CHECK-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}> : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> +// CHECK-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>}> : () -> tensor<32x!quant.uniform> +// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> // CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> -// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> 
tensor<1x401408x!quant.uniform> // CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[softmax]]) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> // CHECK-NEXT: return %[[dq]] : tensor<1x401408xf32> // INT8-LABEL: @modified(%arg0: tensor<1x224x224x3x!quant.uniform>) -> tensor<1x401408x!quant.uniform> // INT8-NEXT: %[[shape:.*]] = arith.constant dense<[1, 401408]> : tensor<2xi32> -// INT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> -// INT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> -// INT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[cst1]], %[[cst2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> +// INT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}> : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> +// INT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>}> : () -> tensor<32x!quant.uniform> +// INT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[cst1]], %[[cst2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> // INT8-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> -// INT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// INT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> // INT8-NEXT: return %[[softmax]] : tensor<1x401408x!quant.uniform> // UINT8-LABEL: func @modified(%arg0: tensor<1x224x224x3x!quant.uniform>) -> tensor<1x401408x!quant.uniform> // UINT8-NEXT: %[[shape:.*]] = arith.constant dense<[1, 401408]> : tensor<2xi32> -// UINT8-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3x!quant.uniform> -// UINT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> -// UINT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> -// UINT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, 
tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> +// UINT8-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x224x224x3x!quant.uniform>}> : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3x!quant.uniform> +// UINT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}> : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> +// UINT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>}> : () -> tensor<32x!quant.uniform> +// UINT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> // UINT8-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> -// UINT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> -// UINT8-NEXT: %[[dq:.*]] = "tfl.quantize"(%[[softmax]]) {qtype = tensor<1x401408x!quant.uniform>} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// UINT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// UINT8-NEXT: %[[dq:.*]] = "tfl.quantize"(%[[softmax]]) <{qtype = tensor<1x401408x!quant.uniform>}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> // UINT8-NEXT: return %[[dq]] : tensor<1x401408x!quant.uniform> } @@ -52,40 +52,40 @@ func.func @not_modified(%arg0: tensor, %arg1: tensor<1x224x224x3xf32>) -> ( %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> %4 = "tfl.reshape"(%3, %cst) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> - %5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> + %5 = "tfl.softmax"(%4) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> %6 = "tfl.dequantize"(%5) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> func.return %6, %arg1 : tensor<1x401408xf32>, tensor<1x224x224x3xf32> // CHECK-LABEL: func @not_modified(%arg0: tensor, %arg1: tensor<1x224x224x3xf32>) -> (tensor<1x401408xf32>, tensor<1x224x224x3xf32>) // CHECK-NEXT: %[[shape:.*]] = arith.constant dense<[1, 401408]> : tensor<2xi32> -// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> -// CHECK-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = 
tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> -// CHECK-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> -// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> +// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x224x224x3x!quant.uniform>}> : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> +// CHECK-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}> : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> +// CHECK-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>}> : () -> tensor<32x!quant.uniform> +// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> // CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> -// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> // CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[softmax]]) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> // CHECK-NEXT: return %[[dq]], %arg1 : tensor<1x401408xf32>, tensor<1x224x224x3xf32> // INT8-LABEL: @not_modified(%arg0: tensor, %arg1: tensor<1x224x224x3xf32>) -> (tensor<1x401408x!quant.uniform>, tensor<1x224x224x3xf32>) // INT8-NEXT: %[[shape:.*]] = arith.constant dense<[1, 401408]> : tensor<2xi32> -// INT8-NEXT: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> -// INT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> -// INT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> -// INT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> +// INT8-NEXT: %[[q:.*]] = 
"tfl.quantize"(%arg1) <{qtype = tensor<1x224x224x3x!quant.uniform>}> : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> +// INT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}> : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> +// INT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>}> : () -> tensor<32x!quant.uniform> +// INT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> // INT8-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> -// INT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// INT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> // INT8-NEXT: return %[[softmax]], %arg1 : tensor<1x401408x!quant.uniform>, tensor<1x224x224x3xf32> // UINT8-LABEL: func @not_modified(%arg0: tensor, %arg1: tensor<1x224x224x3xf32>) -> (tensor<1x401408x!quant.uniform>, tensor<1x224x224x3xf32>) // UINT8-NEXT: %[[shape:.*]] = arith.constant dense<[1, 401408]> : tensor<2xi32> -// UINT8-NEXT: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> -// UINT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> -// UINT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> -// UINT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> +// UINT8-NEXT: %[[q:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x224x224x3x!quant.uniform>}> : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> +// UINT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}> : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> +// UINT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>}> : () -> tensor<32x!quant.uniform> +// UINT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 
0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> // UINT8-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> -// UINT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> -// UINT8-NEXT: %[[dq:.*]] = "tfl.quantize"(%[[softmax]]) {qtype = tensor<1x401408x!quant.uniform>} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// UINT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// UINT8-NEXT: %[[dq:.*]] = "tfl.quantize"(%[[softmax]]) <{qtype = tensor<1x401408x!quant.uniform>}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> // UINT8-NEXT: return %[[dq]], %arg1 : tensor<1x401408x!quant.uniform>, tensor<1x224x224x3xf32> } @@ -96,7 +96,7 @@ func.func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> %4 = "tfl.reshape"(%3, %cst) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> - %5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> + %5 = "tfl.softmax"(%4) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> %6 = "tfl.dequantize"(%5) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> func.return %6 : tensor<1x401408xf32> @@ -112,7 +112,7 @@ func.func @non_entry_funciton(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408 %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> %4 = "tfl.reshape"(%3, %cst) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> - %5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> + %5 = "tfl.softmax"(%4) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> %6 = "tfl.dequantize"(%5) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> func.return %6 : tensor<1x401408xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 20c4a031157a89..fa69cd46017f8f 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -401,7 +401,7 @@ func.func @testAddWithI64Broadcasting(tensor< 2x3xi64>, tensor<3xi64>) -> tensor // CHECK-LABEL: 
add_with_i32_five_dim_broadcasting func.func @add_with_i32_five_dim_broadcasting(tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> { ^bb0(%arg0: tensor<1x1x1x1x1xi32>, %arg1: tensor<1xi32>): - // CHECK: tfl.add(%arg0, %arg1) {fused_activation_function = "RELU6"} + // CHECK: tfl.add(%arg0, %arg1) <{fused_activation_function = "RELU6"}> %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> func.return %0#0 : tensor<1x1x1x1x1xi32> } @@ -420,7 +420,7 @@ func.func @add_with_quantized_i16_broadcasting(tensor<2x2xf32>, tensor<1xf32>) - // CHECK-LABEL: sub_with_i32_five_dim_broadcasting func.func @sub_with_i32_five_dim_broadcasting(tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> { ^bb0(%arg0: tensor<1x1x1x1x1xi32>, %arg1: tensor<1xi32>): - // CHECK: tfl.sub(%arg0, %arg1) {fused_activation_function = "RELU6"} + // CHECK: tfl.sub(%arg0, %arg1) <{fused_activation_function = "RELU6"}> %0 = "tfl.sub"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> func.return %0#0 : tensor<1x1x1x1x1xi32> } @@ -438,7 +438,7 @@ func.func @sub_with_quantized_i8_five_dim_broadcasting(tensor<1x1x1x1x1xf32>, te // CHECK-LABEL: mul_with_i32_five_dim_broadcasting func.func @mul_with_i32_five_dim_broadcasting(tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> { ^bb0(%arg0: tensor<1x1x1x1x1xi32>, %arg1: tensor<1xi32>): - // CHECK: tfl.mul(%arg0, %arg1) {fused_activation_function = "RELU6"} + // CHECK: tfl.mul(%arg0, %arg1) <{fused_activation_function = "RELU6"}> %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> func.return %0#0 : tensor<1x1x1x1x1xi32> } @@ -448,7 +448,7 @@ func.func @mul_with_i32_five_dim_broadcasting(tensor<1x1x1x1x1xi32>, tensor<1xi3 // CHECK-LABEL: mul_with_quantized_i16_five_dim_broadcasting func.func @mul_with_quantized_i16_five_dim_broadcasting(tensor<1x1x1x1x1x!quant.any>, tensor<1x!quant.any>) -> tensor<1x1x1x1x1x!quant.any> { ^bb0(%arg0: tensor<1x1x1x1x1x!quant.any>, %arg1: tensor<1x!quant.any>): - // CHECK: tfl.mul(%arg0, %arg1) {fused_activation_function = "RELU6"} + // CHECK: tfl.mul(%arg0, %arg1) <{fused_activation_function = "RELU6"}> %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1x!quant.any>, tensor<1x!quant.any>) -> tensor<1x1x1x1x1x!quant.any> func.return %0#0 : tensor<1x1x1x1x1x!quant.any> } @@ -467,7 +467,7 @@ func.func @mul_with_quantized_i16_to_uint8_broadcasting(tensor<1x1x!quant.any, tensor) -> tensor> { ^bb0(%arg0: tensor, %arg1: tensor): - // CHECK: tfl.mul(%arg0, %arg1) {fused_activation_function = "RELU6"} + // CHECK: tfl.mul(%arg0, %arg1) <{fused_activation_function = "RELU6"}> %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor, tensor) -> tensor> func.return %0#0 : tensor> } @@ -614,7 +614,7 @@ func.func @testConv2D4DBias(tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tenso // CHECK-LABEL: testFakeQuant func.func @testFakeQuant(tensor, f32, f32) -> tensor { ^bb0(%arg0: tensor, %arg1: f32, %arg2: f32): - // CHECK: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32} : (tensor) -> tensor + // CHECK: "tfl.fake_quant"(%arg0) <{max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32}> : (tensor) -> tensor %1 = "tfl.fake_quant"(%arg0) {num_bits = 6 : i32, 
narrow_range = false, min = 0.3:f32, max = 1.4:f32} : (tensor) -> tensor func.return %1 : tensor } @@ -622,7 +622,7 @@ func.func @testFakeQuant(tensor, f32, f32) -> tensor { // CHECK-LABEL: testQuantize func.func @testQuantize(tensor) -> tensor> { ^bb0(%arg0: tensor): - // CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor>} + // CHECK: %0 = "tfl.quantize"(%arg0) <{qtype = tensor>}> %0 = "tfl.quantize"(%arg0) {qtype = tensor>} : (tensor) -> tensor> func.return %0 : tensor> } @@ -738,7 +738,7 @@ func.func @testPadding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf3 // CHECK-LABEL: testMaxPool2D func.func @testMaxPool2D(tensor<256x32x32x3xf32>) -> tensor { ^bb0(%arg0: tensor<256x32x32x3xf32>): - // CHECK: "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>) -> tensor + // CHECK: "tfl.max_pool_2d"(%arg0) <{filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<256x32x32x3xf32>) -> tensor %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>) -> tensor func.return %0 : tensor } @@ -748,7 +748,7 @@ func.func @testMaxPool2D(tensor<256x32x32x3xf32>) -> tensor { // CHECK-LABEL: testMaxPool2DQuantized func.func @testMaxPool2DQuantized(tensor<256x32x32x3x!quant.uniform>) -> tensor> { ^bb0(%arg0: tensor<256x32x32x3x!quant.uniform>): - // CHECK: "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} + // CHECK: "tfl.max_pool_2d"(%arg0) <{filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3x!quant.uniform>) -> tensor> func.return %0 : tensor> } @@ -824,7 +824,7 @@ func.func @testLogisticWithWrongInputType(tensor) -> tensor { // CHECK-LABEL: testUnidirectionalSequenceRnn func.func @testUnidirectionalSequenceRnn(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) <{fused_activation_function = "NONE", time_major = false}> : (tensor, tensor, tensor, tensor, tensor) -> tensor %0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -833,7 +833,7 @@ func.func @testUnidirectionalSequenceRnn(%arg0: tensor, %arg1: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: none, %arg17: none, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: 
tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, none, none, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) <{fused_activation_function = "NONE", time_major = false}> : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, none, none, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, none, none, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -842,7 +842,7 @@ func.func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) <{fused_activation_function = "NONE", time_major = false}> : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -860,7 +860,7 @@ func.func 
@testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tens %arg20: none, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, // CHECK-SAME: %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, - // CHECK-SAME: %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : + // CHECK-SAME: %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) <{cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false}> : // CHECK-SAME: (tensor, // CHECK-SAME: none, tensor, tensor, tensor, none, tensor, tensor, tensor, // CHECK-SAME: none, tensor, tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, @@ -879,7 +879,7 @@ func.func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tens // CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates func.func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) <{cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = 
tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -900,16 +900,17 @@ func.func @testLstmIntermediates(%arg0: tensor<1x528x!quant.uniform none %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0075630000792443752:2>>, kernel_type = #tfl, proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> func.return %0 : tensor<1x640x!quant.uniform> -// CHECK: %[[RES0:.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({ -// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, kernel_type = #tfl, proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, 
tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> +// CHECK: %[[RES0:.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) +// CHECK-SAME: <{cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform>, input_to_forget_intermediate = tensor<0x!quant.uniform>, input_to_input_intermediate = tensor<0x!quant.uniform>, input_to_output_intermediate = tensor<0x!quant.uniform>, kernel_type = #tfl, proj_clip = 0.00999999977 : f32}> ({ +// CHECK: }) : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> } // ----- // CHECK-LABEL: testBidirectionalSequenceLstm func.func @testBidirectionalSequenceLstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor, %arg24: tensor, %arg25: tensor, %arg26: tensor, %arg27: tensor, %arg28: tensor, %arg29: tensor, %arg30: tensor, %arg31: tensor, %arg32: tensor, %arg33: tensor, %arg34: tensor, %arg35: tensor, %arg36: tensor, %arg37: tensor, %arg38: tensor, %arg39: tensor, %arg40: tensor, %arg41: tensor, %arg42: tensor, %arg43: tensor, %arg44: tensor, %arg45: tensor, %arg46: tensor, %arg47: tensor) -> tensor { - // CHECK: "tfl.bidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23, %arg24, %arg25, %arg26, %arg27, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40, %arg41, %arg42, %arg43, %arg44, %arg45, %arg46, %arg47) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", merge_outputs = true, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, 
tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) + // CHECK: "tfl.bidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23, %arg24, %arg25, %arg26, %arg27, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40, %arg41, %arg42, %arg43, %arg44, %arg45, %arg46, %arg47) <{cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", merge_outputs = true, time_major = false}> : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) %0:2 = "tfl.bidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23, %arg24, %arg25, %arg26, %arg27, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40, %arg41, %arg42, %arg43, %arg44, %arg45, %arg46, %arg47) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", merge_outputs = true, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) func.return %0#0 : tensor } @@ -922,9 +923,10 @@ func.func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform, proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> func.return %0 : tensor<1x640x!quant.uniform> - // CHECK: %[[RES0:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, 
%arg17, %arg18) ({ - // CHECK-NEXT: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = #tfl, proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> + // CHECK: %[[RES0:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) + // CHECK-SAME: <{cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = #tfl, proj_clip = 0.00999999977 : f32}> ({ + // CHECK-NEXT: }) : (tensor<1x528x!quant.uniform>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<640x2048x!quant.uniform:f32, 0.021174000576138496>>, tensor<640x!quant.uniform>, tensor<1x640x!quant.uniform>, tensor<1x2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>, tensor<2048x!quant.uniform>) -> tensor<1x640x!quant.uniform> // CHECK: return %[[RES1]] } @@ -933,7 +935,8 @@ func.func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) - // CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = #tfl} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK-SAME: <{fused_activation_function = "NONE", kernel_type = 
#tfl}> ({ + // CHECK-NEXT: }) : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = #tfl} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -966,8 +969,8 @@ func.func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor, %arg18: tensor, %arg19: tensor, %arg20: none, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) - // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl} : - // CHECK-SAME: (tensor, + // CHECK-SAME: <{cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl}> ({ + // CHECK-NEXT: }) : (tensor, // CHECK-SAME: none, tensor, tensor, tensor, none, tensor, tensor, tensor, // CHECK-SAME: none, tensor, tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, // CHECK-SAME: none, tensor, tensor, tensor) -> tensor @@ -1362,13 +1365,13 @@ func.func @testPadV2UnsupportedPaddings(tensor<*xf32>, tensor<6x3xi32>) -> tenso // ----- func.func @packQuantizedU8(%arg0: tensor<2x!quant.uniform>, %arg1: tensor<2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} + // CHECK: "tfl.pack"(%arg0, %arg1) <{axis = 0 : i32, values_count = 2 : i32}> %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<2x2x!quant.uniform> func.return %0 : tensor<2x2x!quant.uniform> } func.func @packQuantizedI8(%arg0: tensor<2x!quant.uniform>, %arg1: tensor<2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} + // CHECK: "tfl.pack"(%arg0, %arg1) <{axis = 0 : i32, values_count = 2 : i32}> %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<2x2x!quant.uniform> func.return %0 : tensor<2x2x!quant.uniform> } @@ -1376,7 +1379,7 @@ func.func @packQuantizedI8(%arg0: tensor<2x!quant.uniform>, %arg1: // ----- func.func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { - // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} + // CHECK: "tfl.pack"(%arg0, %arg1) <{axis = 0 : i32, values_count = 2 : i32}> %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> func.return %0 : tensor<2x2xi32> } @@ -1384,7 +1387,7 @@ func.func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // ----- func.func @packUnranked(%arg0: tensor<2xi32>, %arg1: tensor<*xi32>) -> tensor<2x2xi32> { - // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} + // CHECK: "tfl.pack"(%arg0, %arg1) <{axis = 0 : i32, values_count = 2 : 
i32}> %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<*xi32>) -> tensor<2x2xi32> func.return %0 : tensor<2x2xi32> } @@ -1392,7 +1395,7 @@ func.func @packUnranked(%arg0: tensor<2xi32>, %arg1: tensor<*xi32>) -> tensor<2x // ----- func.func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> { - // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} + // CHECK: "tfl.pack"(%arg0, %arg1) <{axis = 2 : i32, values_count = 2 : i32}> %0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32> func.return %0 : tensor<1x4x2xi32> } @@ -1400,13 +1403,13 @@ func.func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tens // ----- func.func @packNegInputAxis2(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x2x4xi32> { - // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} + // CHECK: "tfl.pack"(%arg0, %arg1) <{axis = -2 : i32, values_count = 2 : i32}> %0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x2x4xi32> func.return %0 : tensor<1x2x4xi32> } func.func @packNegInputAxis3(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> { - // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32} + // CHECK: "tfl.pack"(%arg0, %arg1) <{axis = -3 : i32, values_count = 2 : i32}> %0 = "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32> func.return %0 : tensor<2x1x4xi32> } @@ -1414,7 +1417,7 @@ func.func @packNegInputAxis3(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> // ----- func.func @packInputUnranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { - // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} + // CHECK: "tfl.pack"(%arg0, %arg1) <{axis = -2 : i32, values_count = 2 : i32}> %0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> func.return %0 : tensor<*xi32> } @@ -1446,7 +1449,7 @@ func.func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // ----- func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { - // CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} + // CHECK: "tfl.unpack"(%arg0) <{axis = 1 : i32, num = 3 : i32}> %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) func.return %0#0 : tensor<2xi32> } @@ -1454,7 +1457,7 @@ func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // ----- func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { - // CHECK: "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} + // CHECK: "tfl.unpack"(%arg0) <{axis = -1 : i32, num = 3 : i32}> %0:3 = "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) func.return %0#0 : tensor<2xi32> } @@ -1462,7 +1465,7 @@ func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // ----- func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<3xi32> { - // CHECK: "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} + // CHECK: "tfl.unpack"(%arg0) <{axis = -2 : i32, num = 2 : i32}> %0:2 = "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<3xi32>, tensor<3xi32>) func.return %0#0 : tensor<3xi32> } @@ -1538,7 +1541,7 @@ 
func.func @unpack(%arg0: tensor) -> () { // CHECK-LABEL: testMean func.func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> { - // CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false} + // CHECK: "tfl.mean"(%arg0, %arg1) <{keep_dims = false}> %0 = "tfl.mean"(%arg0, %arg1) {keep_dims = false}: (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32> func.return %0 : tensor<1x2xf32> } @@ -1547,7 +1550,7 @@ func.func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2 // CHECK-LABEL: testMean_true func.func @testMean_true(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> { - // CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = true} + // CHECK: "tfl.mean"(%arg0, %arg1) <{keep_dims = true}> %0 = "tfl.mean"(%arg0, %arg1) {keep_dims = true}: (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32> func.return %0 : tensor<1x2xf32> } @@ -1597,7 +1600,7 @@ func.func @testBatchMatmulHybridQuant(%arg0 : tensor<1x4x384x32xf32>, %arg1 : te // ----- func.func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<2x2xi32> { - // CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} + // CHECK: "tfl.concatenation"(%arg0, %arg1) <{axis = 0 : i32, fused_activation_function = "NONE"}> %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> func.return %0 : tensor<2x2xi32> } @@ -1605,7 +1608,7 @@ func.func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor< // ----- func.func @testConcatQuantized(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - // CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} + // CHECK: "tfl.concatenation"(%arg0, %arg1) <{axis = 0 : i32, fused_activation_function = "NONE"}> %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> func.return %0 : tensor<2x2x!quant.uniform> } @@ -1692,7 +1695,7 @@ func.func @testConcatBenignDynamicDimSizeOperand(%arg0: tensor<1x?xi32>, %arg1: // CHECK-LABEL: testResizeBilinear func.func @testResizeBilinear(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor { - // CHECK: "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false, half_pixel_centers = false} + // CHECK: "tfl.resize_bilinear"(%arg0, %arg1) <{align_corners = false, half_pixel_centers = false}> %0 = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false, half_pixel_centers = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor func.return %0 : tensor } @@ -1709,7 +1712,7 @@ func.func @testResizeBilinearInvalidOutputType(%arg0 : tensor<1x100x100x3xf32>, // CHECK-LABEL: testStridedSlice func.func @testStridedSlice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> { - // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<12x2x2x5xf32>, 
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> func.return %0 : tensor<1x2x2x5xf32> } @@ -1750,7 +1753,7 @@ func.func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %a // CHECK-LABEL: testOneHot func.func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xf32> { - // CHECK: "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32> + // CHECK: "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) <{axis = -1 : i32}> : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32> %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1784,7 +1787,7 @@ func.func @testArgMin(%arg0: tensor<3xi32>, %arg1: tensor) -> tensor { // CHECK-LABEL: testSpaceToDepth func.func @testSpaceToDepthF32(%arg0: tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32> { // CHECK: %[[ARG:.*]]: tensor<1x2x2x1xf32> - // CHECK: "tfl.space_to_depth"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32> + // CHECK: "tfl.space_to_depth"(%[[ARG]]) <{block_size = 2 : i32}> : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32> %0 = "tfl.space_to_depth"(%arg0) {block_size = 2: i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32> func.return %0 : tensor<1x1x1x4xf32> } @@ -2046,7 +2049,7 @@ func.func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x1 // CHECK-LABEL: testSvdf func.func @testSvdf(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor { - // CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) <{fused_activation_function = "RELU", rank = 2 : i32}> : (tensor, tensor, tensor, tensor, tensor) -> tensor %0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -2056,7 +2059,7 @@ func.func @testSvdf(%arg0: tensor, %arg1: tensor, %arg2: tenso // CHECK-LABEL: testDepthToSpace func.func @testDepthToSpaceF32(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> { // CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32> - // CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> + // CHECK: "tfl.depth_to_space"(%[[ARG]]) <{block_size = 2 : i32}> : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> %0 = "tfl.depth_to_space"(%arg0) {block_size = 2: i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> func.return %0 : tensor<1x2x2x1xf32> } @@ -2619,7 +2622,7 @@ func.func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32> // CHECK-LABEL: testTransposeConvWithOutputThatHasDynamicSizes func.func @testTransposeConvWithOutputThatHasDynamicSizes(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor { - // CHECK: %[[NONE:.*]] = "tfl.no_value"() {value} : () -> none + // CHECK: %[[NONE:.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK: "tfl.transpose_conv"(%arg0, %arg1, %arg2, %[[NONE]]) %cst 
= "tfl.no_value"() {value = unit} : () -> none %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 75c1a791eeca73..3c2c24baba8972 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -14,7 +14,7 @@ func.func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x %1 = "tfl.relu"(%0) : (tensor<256x32x32x16xf32>) -> tensor<256x32x32x16xf32> func.return %1 : tensor<256x32x32x16xf32> - // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> + // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> // CHECK: return %0 } @@ -24,7 +24,7 @@ func.func @fusedDepthwiseConv2dRelu6(%arg0: tensor<256x32x32x3xf32>, %arg1: tens %1 = "tfl.relu6"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> func.return %1 : tensor<256x30x30x16xf32> - // CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + // CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) <{depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> // CHECK: return %0 } @@ -34,7 +34,7 @@ func.func @fusedMaxPool2dRelu(%arg0: tensor<1x147x147x16xf32>) -> tensor<1x73x73 %1 = "tfl.relu"(%0) : (tensor<1x73x73x16xf32>) -> tensor<1x73x73x16xf32> func.return %1 : tensor<1x73x73x16xf32> - // CHECK: %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + // CHECK: %0 = "tfl.max_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> // CHECK: return %0 } @@ -44,7 +44,7 @@ func.func @fusedAvgPool2dRelu1(%arg0: tensor<1x147x147x16xf32>) -> tensor<1x73x7 %1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x73x73x16xf32>) -> tensor<1x73x73x16xf32> func.return %1 : tensor<1x73x73x16xf32> - // CHECK: %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "RELU_N1_TO_1", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + // CHECK: %0 = "tfl.average_pool_2d"(%arg0) 
<{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "RELU_N1_TO_1", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> // CHECK: return %0 } @@ -188,7 +188,7 @@ func.func @fuseMulIntoTransposeConvNoBias(%arg0: tensor<1x32x42x128xf32>) -> ten // CHECK-DAG: %[[SHAPE:.*]] = arith.constant dense<[1, 64, 84, 32]> : tensor<4xi32> // CHECK-DAG: %[[WEIGHTS:.*]] = arith.constant dense<1.500000e+00> : tensor<32x4x4x128xf32> - // CHECK-DAG: %[[BIAS:.*]] = "tfl.no_value"() {value} : () -> none + // CHECK-DAG: %[[BIAS:.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK: %[[RESULT:.*]] = "tfl.transpose_conv"(%[[SHAPE]], %[[WEIGHTS]], %arg0, %[[BIAS]]) // CHECK: return %[[RESULT]] } @@ -204,7 +204,7 @@ func.func @fuseAddIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor< // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<16x3x3x3xf32> // CHECK-DAG: %[[b:.*]] = "tfl.pseudo_const"(){{.*}}dense<[4.150000e+01, 4.250000e+01, 4.350000e+01, 4.450000e+01, 4.550000e+01, 4.650000e+01, 4.750000e+01, 4.850000e+01, 4.950000e+01, 5.050000e+01, 5.150000e+01, 5.250000e+01, 5.350000e+01, 5.450000e+01, 5.550000e+01, 5.650000e+01]> : tensor<16xf32> -// CHECK-NEXT: %[[c:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> +// CHECK-NEXT: %[[c:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> // CHECK-NEXT: return %[[c]] : tensor<256x30x30x16xf32> } @@ -219,7 +219,7 @@ func.func @fuseSubIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor< // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<16x3x3x3xf32> // CHECK-DAG: %[[b:.*]] = "tfl.pseudo_const"(){{.*}}dense<[-3.950000e+01, -3.850000e+01, -3.750000e+01, -3.650000e+01, -3.550000e+01, -3.450000e+01, -3.350000e+01, -3.250000e+01, -3.150000e+01, -3.050000e+01, -2.950000e+01, -2.850000e+01, -2.750000e+01, -2.650000e+01, -2.550000e+01, -2.450000e+01]> : tensor<16xf32> -// CHECK-NEXT: %[[c:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> +// CHECK-NEXT: %[[c:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> // CHECK-NEXT: return %[[c]] : tensor<256x30x30x16xf32> } @@ -279,7 +279,7 @@ func.func @fuseAddIntoFollowingDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>) - // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<3x3x3x16xf32> // CHECK-DAG: %[[b:.*]] = "tfl.pseudo_const"(){{.*}}dense<[4.150000e+01, 4.250000e+01, 4.350000e+01, 4.450000e+01, 4.550000e+01, 4.650000e+01, 4.750000e+01, 4.850000e+01, 4.950000e+01, 5.050000e+01, 5.150000e+01, 5.250000e+01, 
5.350000e+01, 5.450000e+01, 5.550000e+01, 5.650000e+01]> : tensor<16xf32> -// CHECK-NEXT: %[[dc:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w]], %[[b]]) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> +// CHECK-NEXT: %[[dc:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w]], %[[b]]) <{depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> // CHECK-NEXT: return %[[dc]] : tensor<256x30x30x16xf32> } @@ -322,7 +322,7 @@ func.func @fuseMulIntoConv2dWithQDQs(%arg0: tensor<256x32x32x3xf32>) -> tensor<2 // CHECK-DAG: %[[w:.*]] = "tfl.pseudo_const"(){{.*}}dense<3.000000e+00> : tensor<3x3x3x3xf32> // CHECK-DAG: %[[cst:.*]] = "tfl.pseudo_const"(){{.*}}dense<[1.500000e+00, 3.000000e+00, 4.500000e+00]> : tensor<3xf32> - // CHECK: %[[q:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<3x3x3x3x!quant.uniform:f32:0, {1.500000e+00,3.000000e+00,4.500000e+00}>>} + // CHECK: %[[q:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<3x3x3x3x!quant.uniform:f32:0, {1.500000e+00,3.000000e+00,4.500000e+00}>>}> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]], %[[cst]]) // CHECK: return %[[conv]] : tensor<256x8x7x3xf32> @@ -341,7 +341,7 @@ func.func @fuseMulIntoFullyConnectedWithOptionalAttribute(%arg0: tensor<4x2xf32> // CHECK-DAG: %[[CONSTANT:.*]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> // CHECK-DAG: %[[CONSTANT0:.*]] = "tfl.pseudo_const"(){{.*}}dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> -// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {asymmetric_quantize_inputs = true, +// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) <{asymmetric_quantize_inputs = true, } // CHECK-LABEL: @fuseMulIntoFullyConnected @@ -357,7 +357,7 @@ func.func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> // CHECK-DAG: %[[CONSTANT:.*]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> // CHECK-DAG: %[[CONSTANT0:.*]] = "tfl.pseudo_const"(){{.*}}dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> -// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) <{fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: return %[[RES]] : tensor<4x2xf32> } @@ -372,8 +372,8 @@ func.func @DontFuseMulIntoFullyConnectedForLargeFilter(%arg0: tensor<128x256000x func.return %1 : tensor<128x1024xf32> -// CHECK: %[[a:.*]] = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} -// CHECK: %[[b:.*]] = tfl.mul(%[[a]], %cst) {fused_activation_function = "RELU6"} +// CHECK: %[[a:.*]] = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> +// CHECK: 
%[[b:.*]] = tfl.mul(%[[a]], %cst) <{fused_activation_function = "RELU6"}>
}
@@ -393,9 +393,9 @@ func.func @skipFuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> (tensor<1x8x
// CHECK: %cst_0 = arith.constant dense<2.000000e+00> : tensor<2xf32>
// CHECK: %cst_1 = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
// CHECK: %cst_2 = arith.constant dense<[1, 8]> : tensor<2xi32>
- // CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+ // CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
// CHECK: %1 = "tfl.reshape"(%0, %cst_2) : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<1x8xf32>
- // CHECK: %2 = tfl.mul(%0, %cst_1) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+ // CHECK: %2 = tfl.mul(%0, %cst_1) <{fused_activation_function = "RELU6"}> : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
// CHECK: return %1, %2 : tensor<1x8xf32>, tensor<4x2xf32>
}
@@ -414,7 +414,7 @@ func.func @fuseAddIntoFollowingFullyConnectedWithQDQs(%arg0: tensor<4x2xf32>) ->
// CHECK-DAG: %[[b:.*]] = "tfl.pseudo_const"(){{.*}}dense<[6.500000e+00, 1.250000e+01]> : tensor<2xf32>
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[w]])
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
-// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq]], %[[b]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq]], %[[b]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
// CHECK-NEXT: return %[[fc]] : tensor<4x2xf32>
}
@@ -429,7 +429,7 @@ func.func @fuseAddIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<
// CHECK-DAG: %[[w:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>
// CHECK-DAG: %[[b:.*]] = "tfl.pseudo_const"(){{.*}}dense<[6.500000e+00, 1.250000e+01]> : tensor<2xf32>
-// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
// CHECK-NEXT: return %[[fc]] : tensor<4x2xf32>
}
@@ -456,7 +456,7 @@ func.func @fuseMulIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<
// CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<2xf32>
// CHECK-DAG: %[[w:.*]] = "tfl.pseudo_const"(){{.*}}dense<{{\[}}[1.500000e+00, 3.000000e+00], [4.500000e+00, 6.000000e+00]]> : tensor<2x2xf32>
-// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0,
%[[w]], %[[b]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> // CHECK-NEXT: return %[[fc]] : tensor<4x2xf32> } @@ -472,7 +472,7 @@ func.func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor< // CHECK-DAG: %[[CONSTANT:.*]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> // CHECK-DAG: %[[CONSTANT0:.*]] = "tfl.pseudo_const"(){{.*}}dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> -// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) <{fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: return %[[RES]] : tensor<1x2xf32> } @@ -487,7 +487,7 @@ func.func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) func.return %1 : tensor<4x2xf32> // CHECK-DAG: %[[CONSTANT:.*]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> -// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> +// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) <{fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> // CHECK: return %[[RES]] : tensor<4x2xf32> } @@ -504,7 +504,7 @@ func.func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor< // CHECK-DAG: %cst = arith.constant dense<{{\[\[\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00], [5.000000e+00, 1.200000e+01]], {{\[\[}}7.000000e+00, 1.600000e+01], [9.000000e+00, 2.000000e+01], [1.100000e+01, 2.400000e+01]], {{\[\[}}1.300000e+01, 2.800000e+01], [1.500000e+01, 3.200000e+01], [1.700000e+01, 3.600000e+01]]]]> : tensor<1x3x3x2xf32> // CHECK-DAG: %cst_0 = arith.constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> -// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> +// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0) <{depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> // CHECK: return %0 } @@ -521,7 +521,7 @@ func.func @fuse4DMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tenso // CHECK-DAG: %cst = arith.constant dense<{{\[\[\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00], [5.000000e+00, 1.200000e+01]], {{\[\[}}7.000000e+00, 1.600000e+01], [9.000000e+00, 2.000000e+01], [1.100000e+01, 2.400000e+01]], {{\[\[}}1.300000e+01, 2.800000e+01], [1.500000e+01, 3.200000e+01], [1.700000e+01, 3.600000e+01]]]]> : tensor<1x3x3x2xf32> // CHECK-DAG: %cst_0 = arith.constant 
dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> -// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> +// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0) <{depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> // CHECK: return %0 } @@ -648,13 +648,13 @@ func.func @FuseFullyConnectedMultiUseAddBroadcastedNagative(%arg0: tensor<1x40x3 %4 = "tfl.mul"(%2, %cst1) {fused_activation_function = "NONE"} : (tensor<1x40x4xf32>, tensor<1x1x4xf32>) -> tensor<1x40x4xf32> func.return %1, %3, %4 : tensor<1x40x4xf32>, tensor<1x40x4xf32>, tensor<1x40x4xf32> - // CHECK: %0 = "tfl.no_value"() {value} : () -> none + // CHECK: %0 = "tfl.no_value"() <{value}> : () -> none // CHECK: %cst = arith.constant dense<{{\[\[\[}}2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]]]> : tensor<1x1x4xf32> - // CHECK: %1 = "tfl.fully_connected"(%arg0, %arg1, %0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x40x37xf32>, tensor<4x37xf32>, none) -> tensor<1x40x4xf32> - // CHECK: %2 = tfl.add(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x40x4xf32>, tensor<1x1x4xf32>) -> tensor<1x40x4xf32> - // CHECK: %3 = "tfl.fully_connected"(%arg0, %arg1, %0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x40x37xf32>, tensor<4x37xf32>, none) -> tensor<1x40x4xf32> - // CHECK: %4 = tfl.add(%3, %cst) {fused_activation_function = "NONE"} : (tensor<1x40x4xf32>, tensor<1x1x4xf32>) -> tensor<1x40x4xf32> - // CHECK: %5 = tfl.mul(%3, %cst) {fused_activation_function = "NONE"} : (tensor<1x40x4xf32>, tensor<1x1x4xf32>) -> tensor<1x40x4xf32> + // CHECK: %1 = "tfl.fully_connected"(%arg0, %arg1, %0) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x40x37xf32>, tensor<4x37xf32>, none) -> tensor<1x40x4xf32> + // CHECK: %2 = tfl.add(%1, %cst) <{fused_activation_function = "NONE"}> : (tensor<1x40x4xf32>, tensor<1x1x4xf32>) -> tensor<1x40x4xf32> + // CHECK: %3 = "tfl.fully_connected"(%arg0, %arg1, %0) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x40x37xf32>, tensor<4x37xf32>, none) -> tensor<1x40x4xf32> + // CHECK: %4 = tfl.add(%3, %cst) <{fused_activation_function = "NONE"}> : (tensor<1x40x4xf32>, tensor<1x1x4xf32>) -> tensor<1x40x4xf32> + // CHECK: %5 = tfl.mul(%3, %cst) <{fused_activation_function = "NONE"}> : (tensor<1x40x4xf32>, tensor<1x1x4xf32>) -> tensor<1x40x4xf32> // CHECK: return %2, %4, %5 : tensor<1x40x4xf32>, tensor<1x40x4xf32>, tensor<1x40x4xf32> } @@ -670,7 +670,7 @@ func.func @FuseFullyConnectedBroadcastedBiasAddWithQDQs(%arg0: tensor<40x37xf32> // CHECK: %[[cst:.*]] = arith.constant dense<2.000000e+00> : tensor<40xf32> // CHECK: %[[q:.*]] = "tfl.quantize" - // CHECK-SAME: {qtype = tensor<40x!quant.uniform>} : (tensor<40xf32>) -> tensor<40x!quant.uniform> + // CHECK-SAME: <{qtype = tensor<40x!quant.uniform>}> : (tensor<40xf32>) -> tensor<40x!quant.uniform> // CHECK: %[[dq:.*]] = "tfl.dequantize" // CHECK-SAME: 
(tensor<40x!quant.uniform>) -> tensor<40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"
@@ -735,7 +735,7 @@ func.func @FuseFullyConnectedAddNoBiasWithUnfusableRhs(%arg0: tensor<4x37xf32>,
func.return %1 : tensor<4x4xf32>
- // CHECK-DAG: %[[unit:.*]] = "tfl.no_value"() {value} : () -> none
+ // CHECK-DAG: %[[unit:.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK-DAG: %[[filter:.*]] = arith.constant dense<{{.*}}> : tensor<4x4xf32>
// CHECK: %[[fc_result:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[unit]])
// CHECK: %[[add_result:.*]] = tfl.add %[[fc_result]], %[[filter]]
@@ -769,7 +769,7 @@ func.func @FuseReshapeAroundBMMLHS(%arg0: tensor<6x5x1024xf32>) -> tensor<6x5x81
%2 = "tfl.reshape"(%1, %cst_0) : (tensor<30x8192xf32>, tensor<3xi32>) -> tensor<6x5x8192xf32>
return %2 : tensor<6x5x8192xf32>
// CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<1024x8192xf32>
- // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) {adj_x = false, adj_y = false} : (tensor<6x5x1024xf32>, tensor<1024x8192xf32>) -> tensor<6x5x8192xf32>
+ // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) <{adj_x = false, adj_y = false}> : (tensor<6x5x1024xf32>, tensor<1024x8192xf32>) -> tensor<6x5x8192xf32>
// CHECK: return %0 : tensor<6x5x8192xf32>
}
@@ -784,7 +784,7 @@ func.func @FuseReshapeAroundBMMLHSNegative(%arg0: tensor<1x64xf32>, %arg1: tenso
// CHECK: %cst = arith.constant dense<[1, 1024]> : tensor<2xi32>
// CHECK: %cst_0 = arith.constant dense<[1, 1, 64]> : tensor<3xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<1x64xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
- // CHECK: %1 = "tfl.batch_matmul"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x1x64xf32>, tensor<1x64x1024xf32>) -> tensor<1x1x1024xf32>
+ // CHECK: %1 = "tfl.batch_matmul"(%0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<1x1x64xf32>, tensor<1x64x1024xf32>) -> tensor<1x1x1024xf32>
// CHECK: %2 = "tfl.reshape"(%1, %cst) : (tensor<1x1x1024xf32>, tensor<2xi32>) -> tensor<1x1024xf32>
// CHECK: return %2 : tensor<1x1024xf32>
}
@@ -800,7 +800,7 @@ func.func @FuseReshapeAroundBMMNagativeTest(%arg0: tensor<5x4x1x1024xf32>, %arg1
// CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<3xi32>
// CHECK: %cst_0 = arith.constant dense_resource<__elided__> : tensor<4xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<5x4x1x1024xf32>, tensor<3xi32>) -> tensor<5x4x1024xf32>
- // CHECK: %1 = "tfl.batch_matmul"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<5x4x1024xf32>, tensor<5x1024x8192xf32>) -> tensor<5x4x8192xf32>
+ // CHECK: %1 = "tfl.batch_matmul"(%0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<5x4x1024xf32>, tensor<5x1024x8192xf32>) -> tensor<5x4x8192xf32>
// CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<5x4x8192xf32>, tensor<4xi32>) -> tensor<5x4x1x8192xf32>
// CHECK: return %2 : tensor<5x4x1x8192xf32>
}
@@ -819,8 +819,8 @@ func.func @FuseReshapeAroundBMMNagativeTest2(%arg0: tensor<2x1536xf32>) -> tenso
// CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<3xi32>
// CHECK: %cst_0 = arith.constant dense_resource<__elided__> : tensor<2xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<2x1536xf32>, tensor<3xi32>) -> tensor<2x12x128xf32>
- // CHECK: %1 = "tfl.pseudo_qconst"() {qtype = tensor<128x64x!quant.uniform>, value = dense<9> : tensor<128x64xi8>} : () -> tensor<128x64x!quant.uniform>
- // CHECK: %2 = "tfl.batch_matmul"(%0, %1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = true} : (tensor<2x12x128xf32>, tensor<128x64x!quant.uniform>) -> tensor<2x12x64xf32>
+ // CHECK: %1 = "tfl.pseudo_qconst"() <{qtype = tensor<128x64x!quant.uniform>, value = dense<9> : tensor<128x64xi8>}> : () -> tensor<128x64x!quant.uniform>
+ // CHECK: %2 = "tfl.batch_matmul"(%0, %1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = true}> : (tensor<2x12x128xf32>, tensor<128x64x!quant.uniform>) -> tensor<2x12x64xf32>
// CHECK: %3 = "tfl.reshape"(%2, %cst_0) : (tensor<2x12x64xf32>, tensor<2xi32>) -> tensor<2x768xf32>
// CHECK: return %3 : tensor<2x768xf32>
}
@@ -835,7 +835,7 @@ func.func @FuseReshapeAroundBMMRHS(%arg0: tensor<1x3x6x5x1024xf32>) -> tensor<1x
%2 = "tfl.reshape"(%1, %cst_0) : (tensor<1x90x8192xf32>, tensor<5xi32>) -> tensor<1x3x6x5x8192xf32>
return %2 : tensor<1x3x6x5x8192xf32>
// CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<1x1024x8192xf32>
- // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) {adj_x = false, adj_y = false} : (tensor<1x3x6x5x1024xf32>, tensor<1x1024x8192xf32>) -> tensor<1x3x6x5x8192xf32>
+ // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) <{adj_x = false, adj_y = false}> : (tensor<1x3x6x5x1024xf32>, tensor<1x1024x8192xf32>) -> tensor<1x3x6x5x8192xf32>
// CHECK: return %0 : tensor<1x3x6x5x8192xf32>
}
@@ -845,7 +845,7 @@ func.func @FuseTransposeIntoBMM_LHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tens
%32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x1440x256xf32>, tensor<3xi32>) -> tensor<1x256x1440xf32>
%33 = "tfl.batch_matmul"(%32, %arg0) {adj_x = false, adj_y = false} : (tensor<1x256x1440xf32>, tensor<1x4x1440x256xf32>) -> tensor<1x4x256x256xf32>
return %33 : tensor<1x4x256x256xf32>
- // CHECK: %0 = "tfl.batch_matmul"(%arg1, %arg0) {adj_x = true, adj_y = false} : (tensor<1x1440x256xf32>, tensor<1x4x1440x256xf32>) -> tensor<1x4x256x256xf32>
+ // CHECK: %0 = "tfl.batch_matmul"(%arg1, %arg0) <{adj_x = true, adj_y = false}> : (tensor<1x1440x256xf32>, tensor<1x4x1440x256xf32>) -> tensor<1x4x256x256xf32>
// CHECK: return %0 : tensor<1x4x256x256xf32>
}
@@ -902,8 +902,8 @@ func.func @RetainRedundantReshapeUseInNonBinaryOp(%arg0: tensor<128xf32>, %arg1:
// CHECK-DAG: %cst = arith.constant dense<0> : tensor<1xi32>
// CHECK-DAG: %cst_0 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<128xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32>
- // CHECK: %1 = tfl.mul(%0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x128xf32>, tensor<1x512x512x128xf32>) -> tensor<1x512x512x128xf32>
- // CHECK: %2 = "tfl.reduce_max"(%0, %cst) {keep_dims = false} : (tensor<1x1x1x128xf32>, tensor<1xi32>) -> tensor<128xf32>
+ // CHECK: %1 = tfl.mul(%0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x1x1x128xf32>, tensor<1x512x512x128xf32>) -> tensor<1x512x512x128xf32>
+ // CHECK: %2 = "tfl.reduce_max"(%0, %cst) <{keep_dims = false}> : (tensor<1x1x1x128xf32>, tensor<1xi32>) -> tensor<128xf32>
// CHECK: return %1, %2
}
@@ -963,10 +963,10 @@ func.func @FuseFullyConnectedReshapeAddConstWithOptionalAttribute(%arg0: tensor<
func.return %3 : tensor<40x40xf32>
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<40x40xf32>
- // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {asymmetric_quantize_inputs = true,
+ // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) <{asymmetric_quantize_inputs = true,
// FOLD: %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<40x40xf32>
- // FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1,
%[[cst]]) <{asymmetric_quantize_inputs = true, } // CHECK-LABEL: @FuseFullyConnectedReshapeAddConstWithActivation @@ -985,13 +985,13 @@ func.func @FuseFullyConnectedReshapeAddConstWithActivation(%arg0: tensor<40x37xf func.return %3 : tensor<40x40xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<40x40xf32> - // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) <{fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: %[[rs1:.*]] = "tfl.reshape"(%[[fc]] // CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]] // CHECK: return %[[rs2]] // FOLD: %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<40x40xf32> - // FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} + // FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) <{fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}> // FOLD: return %[[fc]] } @@ -1008,7 +1008,7 @@ func.func @FuseFullyConnectedReshapeAdd2DConst(%arg0: tensor<40x37xf32>, %arg1: func.return %2 : tensor<1x40x4x10xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<2.000000e+00> : tensor<40xf32> - // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] // CHECK: return %[[rs]] } @@ -1026,7 +1026,7 @@ func.func @FuseFCReshapeAdd2DConst2(%arg0: tensor<40x37xf32>, %arg1: tensor<40x3 func.return %2 : tensor<1x40x4x10xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<2.000000e+00> : tensor<40xf32> - // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] // CHECK: return %[[rs]] } @@ -1044,7 +1044,7 @@ func.func @FuseFullyConnectedReshapeAdd2DConstWithActivation(%arg0: tensor<40x37 func.return %2 : tensor<1x40x4x10xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<2.000000e+00> : tensor<40xf32> - // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) <{fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] // CHECK: return %[[rs]] } @@ -1062,7 +1062,7 @@ func.func @FuseFCReshapeAdd2DConstWithActvtn2(%arg0: tensor<40x37xf32>, %arg1: t func.return %2 : tensor<1x40x4x10xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<2.000000e+00> : tensor<40xf32> - // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) <{fused_activation_function = "RELU6", keep_num_dims = 
false, weights_format = "DEFAULT"}> // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] // CHECK: return %[[rs]] } @@ -1328,7 +1328,7 @@ func.func @HardSwishPatternFail(%arg0: tensor<1xf32>) -> tensor<1xf32> { %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> %2 = "tfl.mul"(%1, %six) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor) -> tensor<1xf32> func.return %2: tensor<1xf32> - // CHECK: %0 = tfl.sub(%arg0, %cst) {fused_activation_function = "RELU6"} : (tensor<1xf32>, tensor) -> tensor<1xf32> + // CHECK: %0 = tfl.sub(%arg0, %cst) <{fused_activation_function = "RELU6"}> : (tensor<1xf32>, tensor) -> tensor<1xf32> } // CHECK-LABEL: @L2NormalizePattern @@ -1339,7 +1339,7 @@ func.func @L2NormalizePattern(%arg0: tensor<2xf32>) -> tensor<2xf32> { %2 = "tfl.rsqrt"(%1) : (tensor) -> tensor %3 = "tfl.mul"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor) -> tensor<2xf32> func.return %3: tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) <{fused_activation_function = "NONE"}> : (tensor<2xf32>) -> tensor<2xf32> // CHECK: return %[[RES]] } @@ -1351,7 +1351,7 @@ func.func @L2NormalizePattern1(%arg0: tensor<2xf32>) -> tensor<2xf32> { %2 = "tfl.sqrt"(%1) : (tensor) -> tensor %3 = "tfl.div"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor) -> tensor<2xf32> func.return %3: tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) <{fused_activation_function = "NONE"}> : (tensor<2xf32>) -> tensor<2xf32> // CHECK: return %[[RES]] } @@ -1365,7 +1365,7 @@ func.func @L2NormalizePattern2(%arg0: tensor<2xf32>) -> tensor<2xf32> { %3 = "tfl.rsqrt"(%2) : (tensor<1xf32>) -> tensor<1xf32> %4 = "tfl.mul"(%arg0, %3) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32> func.return %4: tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) <{fused_activation_function = "NONE"}> : (tensor<2xf32>) -> tensor<2xf32> // CHECK: return %[[RES]] } @@ -1379,7 +1379,7 @@ func.func @L2NormalizePattern3(%arg0: tensor<2xf32>) -> tensor<2xf32> { %3 = "tfl.sqrt"(%2) : (tensor<1xf32>) -> tensor<1xf32> %4 = "tfl.div"(%arg0, %3) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32> func.return %4: tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) <{fused_activation_function = "NONE"}> : (tensor<2xf32>) -> tensor<2xf32> // CHECK: return %[[RES]] } @@ -1393,7 +1393,7 @@ func.func @L2NormalizePattern4(%arg0: tensor<2xf32>) -> tensor<2xf32> { %3 = "tfl.sqrt"(%2) : (tensor<1xf32>) -> tensor<1xf32> %4 = "tfl.div"(%arg0, %3) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32> func.return %4: tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: 
%[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) <{fused_activation_function = "NONE"}> : (tensor<2xf32>) -> tensor<2xf32> // CHECK: return %[[RES]] } @@ -1407,7 +1407,7 @@ func.func @L2NormalizePattern5(%arg0: tensor<2xf32>) -> tensor<2xf32> { %3 = "tfl.sqrt"(%2) : (tensor<1xf32>) -> tensor<1xf32> %4 = "tfl.div"(%arg0, %3) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32> func.return %4: tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) <{fused_activation_function = "NONE"}> : (tensor<2xf32>) -> tensor<2xf32> // CHECK: return %[[RES]] } @@ -1420,7 +1420,7 @@ func.func @InvalidL2NormalizePattern(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) %2 = "tfl.sqrt"(%1) : (tensor) -> tensor %3 = "tfl.div"(%arg1, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor) -> tensor<2xf32> func.return %3: tensor<2xf32> - // CHECK: %3 = tfl.div([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor) -> tensor<2xf32> + // CHECK: %3 = tfl.div([[INPUT:%.*]], %2) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor) -> tensor<2xf32> // CHECK: return %3 } @@ -1435,7 +1435,7 @@ func.func @InvalidL2NormalizePattern2(%arg0: tensor<2xf32>, %arg1: tensor<2xf32> %3 = "tfl.sqrt"(%2) : (tensor<1xf32>) -> tensor<1xf32> %4 = "tfl.div"(%arg0, %3) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32> func.return %4 : tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = tfl.div([[INPUT:%.*]], %3) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32> + // CHECK: %[[RES:[0-9].*]] = tfl.div([[INPUT:%.*]], %3) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32> // CHECK: return %[[RES]] } @@ -1448,7 +1448,7 @@ func.func @InvalidL2NormalizePattern3(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> %2 = "tfl.sqrt"(%1) : (tensor) -> tensor %3 = "tfl.div"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> func.return %3: tensor<2x2xf32> - // CHECK: %[[RES:[0-9].*]] = tfl.div([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + // CHECK: %[[RES:[0-9].*]] = tfl.div([[INPUT:%.*]], %2) <{fused_activation_function = "NONE"}> : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> // CHECK: return %[[RES]] } @@ -1463,7 +1463,7 @@ func.func @fuseDivIntoConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x28x23x2 func.return %1 : tensor<1x28x23x2xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<{{\[\[\[\[}}1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]], {{\[\[}}5.000000e+00, 6.000000e+00], [7.000000e+00, 8.000000e+00]]], {{\[\[\[}}4.500000e+00, 5.000000e+00], [5.500000e+00, 6.000000e+00]], {{\[\[}}6.500000e+00, 7.000000e+00], [7.500000e+00, 8.000000e+00]]]]> : tensor<2x2x2x2xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[1.000000e+00, 5.000000e-01]> : tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %cst, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<1x28x23x2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %cst, %cst_0) <{dilation_h_factor = 2 : i32, dilation_w_factor = 3 : 
i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<1x112x112x2xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<1x28x23x2xf32> // CHECK: return %[[RES]] } @@ -1478,7 +1478,7 @@ func.func @fuseDivIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor< func.return %1 : tensor<1x112x112x2xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<{{\[\[\[\[}}1.000000e+00, 1.000000e+00], [3.000000e+00, 2.000000e+00]], {{\[\[}}5.000000e+00, 3.000000e+00], [7.000000e+00, 4.000000e+00]]], {{\[\[\[}}9.000000e+00, 5.000000e+00], [1.100000e+01, 6.000000e+00]], {{\[\[}}1.300000e+01, 7.000000e+00], [1.500000e+01, 8.000000e+00]]]]> : tensor<2x2x2x2xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[1.000000e+00, 5.000000e-01]> : tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0) {depth_multiplier = 1 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0) <{depth_multiplier = 1 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<1x112x112x2xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> // CHECK: return %[[RES]] } @@ -1493,7 +1493,7 @@ func.func @fuseDivIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x func.return %1 : tensor<1x28x23x1xf32> // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<{{\[\[\[\[}}5.000000e-01, 1.000000e+00], [1.500000e+00, 2.000000e+00]], {{\[\[}}2.500000e+00, 3.000000e+00], [3.500000e+00, 4.000000e+00]]]]> : tensor<1x2x2x2xf32> // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<5.000000e-01> : tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x28x23x1xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) <{dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x28x23x1xf32> // CHECK: return %[[RES]] } @@ -1508,7 +1508,7 @@ func.func @fuseMulIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x func.return %1 : tensor<1x28x23x1xf32> // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00], [6.000000e+00, 8.000000e+00]], {{\[\[}}1.000000e+01, 1.200000e+01], [1.400000e+01, 1.600000e+01]]]]> : tensor<1x2x2x2xf32> // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<1xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x28x23x1xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) <{dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", 
padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x28x23x1xf32> // CHECK: return %[[RES]] } @@ -1537,9 +1537,9 @@ func.func @fuseTileWithBinaryOp1(%arg0: tensor<1x1xf32>, %arg1: tensor<1x128xf32 func.return %3 : tensor<1x128xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor - // CHECK: %[[ADD:[0-9].*]] = tfl.add(%arg0, %[[cst]]) {fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor) -> tensor<1x1xf32> + // CHECK: %[[ADD:[0-9].*]] = tfl.add(%arg0, %[[cst]]) <{fused_activation_function = "NONE"}> : (tensor<1x1xf32>, tensor) -> tensor<1x1xf32> // CHECK: %[[SQRT:[0-9].*]] = "tfl.sqrt"(%[[ADD]]) : (tensor<1x1xf32>) -> tensor<1x1xf32> - // CHECK: %[[RES:[0-9].*]] = tfl.div(%[[SQRT]], %arg1) {fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> + // CHECK: %[[RES:[0-9].*]] = tfl.div(%[[SQRT]], %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x1xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> // CHECK: return %[[RES]] } @@ -1640,7 +1640,7 @@ func.func @convertTrivialTransposeToReshape(%arg0: tensor<6x6x256x1xf32>) -> ten %0 = "tfl.transpose"(%arg0, %cst) : (tensor<6x6x256x1xf32>, tensor<4xi32>) -> tensor<1x6x6x256xf32> func.return %0 : tensor<1x6x6x256xf32> - // CHECK-DAG: [[CONST:.*]] = "tfl.pseudo_const"(){{.*}}dense<[1, 6, 6, 256]> : tensor<4xi32> + // CHECK-DAG: [[CONST:.*]] = arith.constant {{.*}}dense<[1, 6, 6, 256]> : tensor<4xi32> // CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor<6x6x256x1xf32>, tensor<4xi32>) -> tensor<1x6x6x256xf32> // CHECK: return %[[RESULT]] } @@ -1797,8 +1797,8 @@ func.func @FusingbiasAdd(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) - func.return %2 : tensor<1x10x10x32xf32> // Fusing-LABEL: FusingbiasAdd -// Fusing: %[[add:[0-9].*]] = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> -// Fusing: %[[add1:[0-9].*]] = tfl.add(%[[add]], %arg1) {fused_activation_function = "RELU6"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> +// Fusing: %[[add:[0-9].*]] = tfl.add(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> +// Fusing: %[[add1:[0-9].*]] = tfl.add(%[[add]], %arg1) <{fused_activation_function = "RELU6"}> : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32> } func.func @FusingdivRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { @@ -2017,7 +2017,7 @@ func.func @FoldSumKeepDim(%arg0: tensor<8x128xf32>) -> tensor<8x1xf32> { func.return %1 : tensor<8x1xf32> // CHECK-LABEL: FoldSumKeepDim -// CHECK: %[[RESULT:.*]] = "tfl.sum"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> +// CHECK: %[[RESULT:.*]] = "tfl.sum"(%arg0, %cst) <{keep_dims = true}> : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> // CHECK: return %[[RESULT]] : tensor<8x1xf32> } @@ -2029,7 +2029,7 @@ func.func @FoldReduceMinKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x128xf32> { func.return %1 : tensor<1x128xf32> // CHECK-LABEL: FoldReduceMinKeepDim -// CHECK: %[[RESULT:.*]] = "tfl.reduce_min"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> +// CHECK: %[[RESULT:.*]] = "tfl.reduce_min"(%arg0, %cst) <{keep_dims = true}> : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> // CHECK: return %[[RESULT]] : tensor<1x128xf32> 
} @@ -2041,7 +2041,7 @@ func.func @FoldReduceMaxKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x128xf32> { func.return %1 : tensor<1x128xf32> // CHECK-LABEL: FoldReduceMaxKeepDim -// CHECK: %[[RESULT:.*]] = "tfl.reduce_max"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> +// CHECK: %[[RESULT:.*]] = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = true}> : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> // CHECK: return %[[RESULT]] : tensor<1x128xf32> } @@ -2053,7 +2053,7 @@ func.func @FoldReduceProdKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x1xf32> { func.return %1 : tensor<1x1xf32> // CHECK-LABEL: FoldReduceProdKeepDim -// CHECK: %[[RESULT:.*]] = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<2xi32>) -> tensor<1x1xf32> +// CHECK: %[[RESULT:.*]] = "tfl.reduce_prod"(%arg0, %cst) <{keep_dims = true}> : (tensor<8x128xf32>, tensor<2xi32>) -> tensor<1x1xf32> // CHECK: return %[[RESULT]] : tensor<1x1xf32> } @@ -2065,7 +2065,7 @@ func.func @FoldMeanKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x128xf32> { func.return %1 : tensor<1x128xf32> // CHECK-LABEL: FoldMeanKeepDim -// CHECK: %[[RESULT:.*]] = "tfl.mean"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> +// CHECK: %[[RESULT:.*]] = "tfl.mean"(%arg0, %cst) <{keep_dims = true}> : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> // CHECK: return %[[RESULT]] : tensor<1x128xf32> } @@ -2079,7 +2079,7 @@ func.func @SoftMaxWithNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf3 func.return %4 : tensor<8x128xf32> // CHECK-LABEL: SoftMaxWithNormalization -// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32> +// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : f32}> : (tensor<8x128xf32>) -> tensor<8x128xf32> // CHECK: return %[[RESULT]] : tensor<8x128xf32> } @@ -2091,7 +2091,7 @@ func.func @SoftMaxWithoutNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128 func.return %2 : tensor<8x128xf32> // CHECK-LABEL: SoftMaxWithoutNormalization -// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32> +// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : f32}> : (tensor<8x128xf32>) -> tensor<8x128xf32> // CHECK: return %[[RESULT]] : tensor<8x128xf32> } @@ -2103,7 +2103,7 @@ func.func @SoftMaxWithoutNormalizationNegAxis(%arg0: tensor<8x128xf32>) -> tenso func.return %2 : tensor<8x128xf32> // CHECK-LABEL: SoftMaxWithoutNormalizationNegAxis -// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32> +// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : f32}> : (tensor<8x128xf32>) -> tensor<8x128xf32> // CHECK: return %[[RESULT]] : tensor<8x128xf32> } @@ -2160,7 +2160,7 @@ func.func @fuseMulIntoConv2d_Splat2D(%arg0: tensor<1x112x112x2xf32>) -> tensor<1 func.return %1 : tensor<1x112x112x2xf32> // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00]]], {{\[\[\[}}6.000000e+00, 8.000000e+00]]]]> : tensor<2x1x1x2xf32> // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<2xf32> - // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, 
tensor<2xf32>) -> tensor<1x112x112x2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> // CHECK: return %[[RES]] } @@ -2174,9 +2174,9 @@ func.func @AvoidFuseFullyConnectedAddWithSplat2D(%arg0: tensor<1x1x1x1x1xf32>, % func.return %1 : tensor<1x1x1x1x1xf32> - // CHECK-DAG: %[[CST1:.*]] = "tfl.no_value"() {value} : () -> none + // CHECK-DAG: %[[CST1:.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<1x1x1x1x1xf32> - // CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[CST1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32> + // CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[CST1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32> // CHECK: %[[ADD:.*]] = tfl.add %[[FC_RESULT]], %[[CST2]] {fused_activation_function = "NONE"} : tensor<1x1x1x1x1xf32> // CHECK: return %[[ADD]] : tensor<1x1x1x1x1xf32> } @@ -2215,7 +2215,7 @@ func.func @DontConvertMul1WithBroadcastToIdentity(%arg0: tensor<2xf32>) -> tenso %0 = "tfl.mul"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> func.return %0 : tensor<2x2xf32> // CHECK-DAG: %cst = arith.constant dense<1.000000e+00> : tensor<2x2xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK: return %0 : tensor<2x2xf32> } @@ -2403,13 +2403,13 @@ func.func @EliminateReduceOpsBool(%arg: tensor<1x2x1x3xi1>, %arg_scalar: tensor< // CHECK-DAG: %[[AXIS_1:.*]] = arith.constant dense<1> : tensor<1xi32> // CHECK-DAG: %[[AXIS_2:.*]] = arith.constant dense<2> : tensor<1xi32> // CHECK-DAG: %[[AXIS_3:.*]] = arith.constant dense<3> : tensor<1xi32> - // CHECK: %[[RET_0:.*]] = "tfl.reduce_any"(%arg0, %[[AXIS_0]]) {keep_dims = false} : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<2x1x3xi1> - // CHECK: %[[RET_1:.*]] = "tfl.reduce_any"(%arg0, %[[AXIS_1]]) {keep_dims = false} : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x1x1x3xi1> - // CHECK: %[[RET_2:.*]] = "tfl.reduce_all"(%arg0, %[[AXIS_1]]) {keep_dims = true} : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x1x3xi1> - // CHECK: %[[RET_3:.*]] = "tfl.reduce_all"(%arg0, %[[AXIS_2]]) {keep_dims = false} : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x2x3xi1> - // CHECK: %[[RET_4:.*]] = "tfl.reduce_all"(%arg0, %[[AXIS_3]]) {keep_dims = false} : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x2x1xi1> - // CHECK: %[[RET_5:.*]] = "tfl.reduce_all"(%arg0, %[[AXIS_3]]) {keep_dims = true} : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x2x1x1xi1> - // CHECK: %[[RET_6:.*]] = "tfl.reduce_all"(%arg2, %arg3) {keep_dims = true} : (tensor, tensor) -> tensor + // CHECK: %[[RET_0:.*]] = "tfl.reduce_any"(%arg0, %[[AXIS_0]]) <{keep_dims = false}> : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<2x1x3xi1> + // CHECK: %[[RET_1:.*]] = "tfl.reduce_any"(%arg0, %[[AXIS_1]]) <{keep_dims = 
false}> : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x1x1x3xi1> + // CHECK: %[[RET_2:.*]] = "tfl.reduce_all"(%arg0, %[[AXIS_1]]) <{keep_dims = true}> : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x1x3xi1> + // CHECK: %[[RET_3:.*]] = "tfl.reduce_all"(%arg0, %[[AXIS_2]]) <{keep_dims = false}> : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x2x3xi1> + // CHECK: %[[RET_4:.*]] = "tfl.reduce_all"(%arg0, %[[AXIS_3]]) <{keep_dims = false}> : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x2x1xi1> + // CHECK: %[[RET_5:.*]] = "tfl.reduce_all"(%arg0, %[[AXIS_3]]) <{keep_dims = true}> : (tensor<1x2x1x3xi1>, tensor<1xi32>) -> tensor<1x2x1x1xi1> + // CHECK: %[[RET_6:.*]] = "tfl.reduce_all"(%arg2, %arg3) <{keep_dims = true}> : (tensor, tensor) -> tensor // CHECK: return %arg1, %arg1, %[[RET_0]], %arg0, %[[RET_1]], %[[RET_2]], %[[RET_3]], %arg0, %[[RET_4]], %[[RET_5]], %[[RET_6]] : tensor, tensor, tensor<2x1x3xi1>, tensor<1x2x1x3xi1>, tensor<1x1x1x3xi1>, tensor<1x1x3xi1>, tensor<1x2x3xi1>, tensor<1x2x1x3xi1>, tensor<1x2x1xi1>, tensor<1x2x1x1xi1>, tensor } @@ -2436,13 +2436,13 @@ func.func @EliminateReduceOpsFloat(%arg: tensor<1x2x1x3xf32>, %arg_scalar: tenso // CHECK-DAG: %[[AXIS_1:.*]] = arith.constant dense<1> : tensor<1xi32> // CHECK-DAG: %[[AXIS_2:.*]] = arith.constant dense<2> : tensor<1xi32> // CHECK-DAG: %[[AXIS_3:.*]] = arith.constant dense<3> : tensor<1xi32> - // CHECK: %[[RET_0:.*]] = "tfl.reduce_min"(%arg0, %[[AXIS_0]]) {keep_dims = false} : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<2x1x3xf32> - // CHECK: %[[RET_1:.*]] = "tfl.reduce_prod"(%arg0, %[[AXIS_1]]) {keep_dims = false} : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x1x1x3xf32> - // CHECK: %[[RET_2:.*]] = "tfl.mean"(%arg0, %[[AXIS_1]]) {keep_dims = true} : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x1x3xf32> - // CHECK: %[[RET_3:.*]] = "tfl.sum"(%arg0, %[[AXIS_2]]) {keep_dims = false} : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x2x3xf32> - // CHECK: %[[RET_4:.*]] = "tfl.reduce_max"(%arg0, %[[AXIS_3]]) {keep_dims = false} : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x2x1xf32> - // CHECK: %[[RET_5:.*]] = "tfl.reduce_prod"(%arg0, %[[AXIS_3]]) {keep_dims = true} : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x2x1x1xf32> - // CHECK: %[[RET_6:.*]] = "tfl.sum"(%arg2, %arg3) {keep_dims = true} : (tensor, tensor) -> tensor + // CHECK: %[[RET_0:.*]] = "tfl.reduce_min"(%arg0, %[[AXIS_0]]) <{keep_dims = false}> : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<2x1x3xf32> + // CHECK: %[[RET_1:.*]] = "tfl.reduce_prod"(%arg0, %[[AXIS_1]]) <{keep_dims = false}> : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x1x1x3xf32> + // CHECK: %[[RET_2:.*]] = "tfl.mean"(%arg0, %[[AXIS_1]]) <{keep_dims = true}> : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x1x3xf32> + // CHECK: %[[RET_3:.*]] = "tfl.sum"(%arg0, %[[AXIS_2]]) <{keep_dims = false}> : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x2x3xf32> + // CHECK: %[[RET_4:.*]] = "tfl.reduce_max"(%arg0, %[[AXIS_3]]) <{keep_dims = false}> : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x2x1xf32> + // CHECK: %[[RET_5:.*]] = "tfl.reduce_prod"(%arg0, %[[AXIS_3]]) <{keep_dims = true}> : (tensor<1x2x1x3xf32>, tensor<1xi32>) -> tensor<1x2x1x1xf32> + // CHECK: %[[RET_6:.*]] = "tfl.sum"(%arg2, %arg3) <{keep_dims = true}> : (tensor, tensor) -> tensor // CHECK: return %arg1, %arg1, %[[RET_0]], %arg0, %[[RET_1]], %[[RET_2]], %[[RET_3]], %arg0, %[[RET_4]], %[[RET_5]], %[[RET_6]] : tensor, tensor, tensor<2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x1x1x3xf32>, tensor<1x1x3xf32>, 
tensor<1x2x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1xf32>, tensor<1x2x1x1xf32>, tensor } @@ -2497,7 +2497,7 @@ func.func @DontRemoveSoftmaxNegativeBetaBeforeArgmax(%arg0: tensor<16x1024xf32>) %1 = "tfl.arg_max"(%0, %cst) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32> func.return %1 : tensor<16xi32> // CHECK-DAG: %[[CST:.*]] = arith.constant dense<-1> : tensor<1xi32> - // CHECK: %[[SOFTMAX:.*]] = "tfl.softmax"(%arg0) {beta = -1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32> + // CHECK: %[[SOFTMAX:.*]] = "tfl.softmax"(%arg0) <{beta = -1.000000e+00 : f32}> : (tensor<16x1024xf32>) -> tensor<16x1024xf32> // CHECK: %[[ARG_MAX:.*]] = "tfl.arg_max"(%[[SOFTMAX]], %[[CST]]) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32> // CHECK: return %[[ARG_MAX]] : tensor<16xi32> } @@ -2509,7 +2509,7 @@ func.func @DontRemoveSoftmaxNonLastAxisBeforeArgmax(%arg0: tensor<16x1024xf32>) %1 = "tfl.arg_max"(%0, %cst) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32> func.return %1 : tensor<16xi32> // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : tensor<1xi32> - // CHECK: %[[SOFTMAX:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<16x1024xf32>) -> tensor<16x1024xf32> + // CHECK: %[[SOFTMAX:.*]] = "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : f32}> : (tensor<16x1024xf32>) -> tensor<16x1024xf32> // CHECK: %[[ARG_MAX:.*]] = "tfl.arg_max"(%[[SOFTMAX]], %[[CST]]) : (tensor<16x1024xf32>, tensor<1xi32>) -> tensor<16xi32> // CHECK: return %[[ARG_MAX]] : tensor<16xi32> } @@ -2633,7 +2633,7 @@ func.func @FuseAddWithFullyConnectedWithBias(%arg: tensor<2x512xf32>) -> tensor< // 2.0 * 3.0 * 512 + 5.0 = 3077.0 // CHECK-DAG: %[[WEIGHTS:.*]] = arith.constant dense<3.000000e+00> : tensor<1024x512xf32> // CHECK-DAG: %[[BIAS:.*]] = arith.constant dense<3.077000e+03> : tensor<1024xf32> - // CHECK: %[[RESULT:.*]] = "tfl.fully_connected"(%arg0, %[[WEIGHTS]], %[[BIAS]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> + // CHECK: %[[RESULT:.*]] = "tfl.fully_connected"(%arg0, %[[WEIGHTS]], %[[BIAS]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> // CHECK: return %[[RESULT]] } @@ -2652,6 +2652,17 @@ func.func @FuseAddWithFullyConnectedWithQuantizedWeight(%arg: tensor<2x512xf32>) // CHECK: tfl.add } +// CHECK-LABEL: @FuseBatchMatMulAndTransposeWithQuantizedWeight +func.func @FuseBatchMatMulAndTransposeWithQuantizedWeight(%arg: tensor<1x2xf32>) -> tensor<1x3xf32> { + %cst_3 = arith.constant dense<[1, 0]> : tensor<2xi32> + %79 = "tfl.pseudo_qconst"() {qtype = tensor<3x2x!quant.uniform:f32:0, {2.378620e-03,2.848260e-03,2.545190e-03}>>, value = dense<10> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform:f32:0, {2.378620e-03,2.848260e-03,2.545190e-03}>> + %80 = "tfl.transpose"(%79, %cst_3) : (tensor<3x2x!quant.uniform:f32:0, {2.378620e-03,2.848260e-03,2.545190e-03}>>, tensor<2xi32>) -> tensor<2x3x!quant.uniform:f32:1, {2.378620e-03,2.848260e-03,2.545190e-03}>> + %81 = "tfl.batch_matmul"(%arg, %80) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32:1, {2.378620e-03,2.848260e-03,2.545190e-03}>>) -> tensor<1x3xf32> + func.return %81 : tensor<1x3xf32> + + // CHECK: tfl.fully_connected +} + // CHECK-LABEL: @FuseAddWithFullyConnectedNoBias // Note: Currently not fused. 
func.func @FuseAddWithFullyConnectedNoBias(%arg: tensor<2x512xf32>) -> tensor<2x1024xf32> { @@ -2666,9 +2677,9 @@ func.func @FuseAddWithFullyConnectedNoBias(%arg: tensor<2x512xf32>) -> tensor<2x // CHECK-DAG: %[[ADDEND:.*]] = arith.constant dense<2.000000e+00> : tensor<512xf32> // CHECK-DAG: %[[WEIGHTS:.*]] = arith.constant dense<3.000000e+00> : tensor<1024x512xf32> - // CHECK-DAG: %[[BIAS:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[VAL_0:.*]] = tfl.add(%arg0, %[[ADDEND]]) {fused_activation_function = "NONE"} : (tensor<2x512xf32>, tensor<512xf32>) -> tensor<2x512xf32> - // CHECK: %[[VAL_1:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[WEIGHTS]], %[[BIAS]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, none) -> tensor<2x1024xf32> + // CHECK-DAG: %[[BIAS:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: %[[VAL_0:.*]] = tfl.add(%arg0, %[[ADDEND]]) <{fused_activation_function = "NONE"}> : (tensor<2x512xf32>, tensor<512xf32>) -> tensor<2x512xf32> + // CHECK: %[[VAL_1:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[WEIGHTS]], %[[BIAS]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, none) -> tensor<2x1024xf32> // CHECK: return %[[VAL_1]] } @@ -2687,7 +2698,7 @@ func.func @DontFuseAddWithFullyConnectedMismatchedDimensions(%arg: tensor<2x512x // CHECK-DAG: %[[WEIGHTS:.*]] = arith.constant dense<3.000000e+00> : tensor<1024x512xf32> // CHECK-DAG: %[[BIAS:.*]] = arith.constant dense<5.000000e+00> : tensor<1024xf32> // CHECK: %[[VAL_0:.*]] = tfl.add %arg0, %[[ADDEND]] {fused_activation_function = "NONE"} : tensor<2x512xf32> - // CHECK: %[[VAL_1:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[WEIGHTS]], %[[BIAS]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> + // CHECK: %[[VAL_1:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[WEIGHTS]], %[[BIAS]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> // CHECK: return %[[VAL_1]] } @@ -2704,7 +2715,7 @@ func.func @FuseMulWithFullyConnectedWithBias(%arg: tensor<2x512xf32>) -> tensor< // CHECK-DAG: %[[WEIGHTS:.*]] = arith.constant dense<6.000000e+00> : tensor<1024x512xf32> // CHECK-DAG: %[[BIAS:.*]] = arith.constant dense<5.000000e+00> : tensor<1024xf32> - // CHECK: %[[RESULT:.*]] = "tfl.fully_connected"(%arg0, %[[WEIGHTS]], %[[BIAS]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> + // CHECK: %[[RESULT:.*]] = "tfl.fully_connected"(%arg0, %[[WEIGHTS]], %[[BIAS]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> // CHECK: return %[[RESULT]] } @@ -2735,16 +2746,16 @@ func.func @FuseMulWithFullyConnectedNoBias(%arg: tensor<2x512xf32>) -> tensor<2x func.return %1 : tensor<2x1024xf32> // CHECK-DAG: %[[WEIGHTS:.*]] = arith.constant dense<6.000000e+00> : tensor<1024x512xf32> - // CHECK-DAG: %[[BIAS:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[VAL_0:.*]] = "tfl.fully_connected"(%arg0, %[[WEIGHTS]], %[[BIAS]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format 
= "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, none) -> tensor<2x1024xf32> + // CHECK-DAG: %[[BIAS:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: %[[VAL_0:.*]] = "tfl.fully_connected"(%arg0, %[[WEIGHTS]], %[[BIAS]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, none) -> tensor<2x1024xf32> // CHECK: return %[[VAL_0]] // NoFusing-LABEL: FuseMulWithFullyConnectedNoBias // NoFusing-DAG: %[[MWEIGHTS:.*]] = arith.constant dense<2.000000e+00> : tensor<512xf32> // NoFusing-DAG: %[[WEIGHTS:.*]] = arith.constant dense<3.000000e+00> : tensor<1024x512xf32> - // NoFusing-DAG: %[[BIAS:.*]] = "tfl.no_value"() {value} : () -> none - // NoFusing: %[[MUL:.*]] = tfl.mul(%arg0, %[[MWEIGHTS]]) {fused_activation_function = "NONE"} : (tensor<2x512xf32>, tensor<512xf32>) -> tensor<2x512xf32> - // NoFusing: %[[VAL:.*]] = "tfl.fully_connected"(%[[MUL]], %[[WEIGHTS]], %[[BIAS]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, none) -> tensor<2x1024xf32> + // NoFusing-DAG: %[[BIAS:.*]] = "tfl.no_value"() <{value}> : () -> none + // NoFusing: %[[MUL:.*]] = tfl.mul(%arg0, %[[MWEIGHTS]]) <{fused_activation_function = "NONE"}> : (tensor<2x512xf32>, tensor<512xf32>) -> tensor<2x512xf32> + // NoFusing: %[[VAL:.*]] = "tfl.fully_connected"(%[[MUL]], %[[WEIGHTS]], %[[BIAS]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, none) -> tensor<2x1024xf32> // NoFusing: return %[[VAL]] } @@ -2760,8 +2771,8 @@ func.func @FuseMulWithFullyConnectedNoBiasWithOptionalAttribute(%arg: tensor<2x5 func.return %1 : tensor<2x1024xf32> // CHECK-DAG: %[[WEIGHTS:.*]] = arith.constant dense<6.000000e+00> : tensor<1024x512xf32> - // CHECK-DAG: %[[BIAS:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[VAL_0:.*]] = "tfl.fully_connected"(%arg0, %[[WEIGHTS]], %[[BIAS]]) {asymmetric_quantize_inputs = true, + // CHECK-DAG: %[[BIAS:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: %[[VAL_0:.*]] = "tfl.fully_connected"(%arg0, %[[WEIGHTS]], %[[BIAS]]) <{asymmetric_quantize_inputs = true, } // CHECK-LABEL: @DontFuseMulWithFullyConnectedMismatchedDimensions @@ -2779,7 +2790,7 @@ func.func @DontFuseMulWithFullyConnectedMismatchedDimensions(%arg: tensor<2x512x // CHECK-DAG: %[[WEIGHTS:.*]] = arith.constant dense<3.000000e+00> : tensor<1024x512xf32> // CHECK-DAG: %[[BIAS:.*]] = arith.constant dense<5.000000e+00> : tensor<1024xf32> // CHECK: %[[VAL_0:.*]] = tfl.mul %arg0, %[[MULTIPLIER]] {fused_activation_function = "NONE"} : tensor<2x512xf32> - // CHECK: %[[VAL_1:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[WEIGHTS]], %[[BIAS]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> + // CHECK: %[[VAL_1:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[WEIGHTS]], %[[BIAS]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> // CHECK: return %[[VAL_1]] } @@ -2789,7 +2800,7 @@ func.func @RemoveReshapeBeforeFullyConnectedExpandDims0(%arg0: tensor<128x64xf32 %0 = "tfl.reshape"(%arg0, %cst) : (tensor<128x64xf32>, tensor<3xi32>) -> tensor<1x128x64xf32> %1 = "tfl.fully_connected"(%0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims 
= false, weights_format = "DEFAULT"} : (tensor<1x128x64xf32>, tensor<32x64xf32>, tensor<32xf32>) -> tensor<128x32xf32> func.return %1 : tensor<128x32xf32> - // CHECK: %[[FULLY_CONNECTED:.*]] = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<128x64xf32>, tensor<32x64xf32>, tensor<32xf32>) -> tensor<128x32xf32> + // CHECK: %[[FULLY_CONNECTED:.*]] = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<128x64xf32>, tensor<32x64xf32>, tensor<32xf32>) -> tensor<128x32xf32> // CHECK: return %[[FULLY_CONNECTED]] : tensor<128x32xf32> } @@ -2799,7 +2810,7 @@ func.func @RemoveReshapeBeforeFullyConnectedReshape(%arg0: tensor<128x64xf32>, % %0 = "tfl.reshape"(%arg0, %cst) : (tensor<128x64xf32>, tensor<3xi32>) -> tensor<4x32x64xf32> %1 = "tfl.fully_connected"(%0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x32x64xf32>, tensor<32x64xf32>, tensor<32xf32>) -> tensor<128x32xf32> func.return %1 : tensor<128x32xf32> - // CHECK: %[[FULLY_CONNECTED:.*]] = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<128x64xf32>, tensor<32x64xf32>, tensor<32xf32>) -> tensor<128x32xf32> + // CHECK: %[[FULLY_CONNECTED:.*]] = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<128x64xf32>, tensor<32x64xf32>, tensor<32xf32>) -> tensor<128x32xf32> // CHECK: return %[[FULLY_CONNECTED]] : tensor<128x32xf32> } @@ -2811,7 +2822,7 @@ func.func @DontRemoveReshapeBeforeFullyConnectedKeepNumDims(%arg0: tensor<128x64 func.return %1 : tensor<1x128x32xf32> // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[1, 128, 64]> : tensor<3xi32> // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<128x64xf32>, tensor<3xi32>) -> tensor<1x128x64xf32> - // CHECK: %[[FULLY_CONNECTED:.*]] = "tfl.fully_connected"(%[[RESHAPE]], %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<1x128x64xf32>, tensor<32x64xf32>, tensor<32xf32>) -> tensor<1x128x32xf32> + // CHECK: %[[FULLY_CONNECTED:.*]] = "tfl.fully_connected"(%[[RESHAPE]], %arg1, %arg2) <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<1x128x64xf32>, tensor<32x64xf32>, tensor<32xf32>) -> tensor<1x128x32xf32> // CHECK: return %[[FULLY_CONNECTED]] : tensor<1x128x32xf32> } @@ -2823,7 +2834,7 @@ func.func @DontRemoveReshapeBeforeFullyConnectedChangeLastDim(%arg0: tensor<128x func.return %1 : tensor<256x32xf32> // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[1, 256, 32]> : tensor<3xi32> // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<128x64xf32>, tensor<3xi32>) -> tensor<1x256x32xf32> - // CHECK: %[[FULLY_CONNECTED:.*]] = "tfl.fully_connected"(%[[RESHAPE]], %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256x32xf32>, tensor<32x32xf32>, tensor<32xf32>) -> tensor<256x32xf32> + // CHECK: %[[FULLY_CONNECTED:.*]] = "tfl.fully_connected"(%[[RESHAPE]], %arg1, %arg2) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x256x32xf32>, tensor<32x32xf32>, tensor<32xf32>) -> tensor<256x32xf32> // CHECK: return %[[FULLY_CONNECTED]] : 
tensor<256x32xf32> } @@ -2838,8 +2849,8 @@ func.func @DontFuseAddWithConvActivationFunc(%arg0: tensor<1x3x1x1xf32>) -> tens // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.500000e+00> : tensor<1xf32> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : tensor<3xf32> // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<1.100000e+00> : tensor<3x2x1x1xf32> - // CHECK: %[[ADD:.*]] = tfl.add(%arg0, %[[CST]]) {fused_activation_function = "RELU6"} : (tensor<1x3x1x1xf32>, tensor<1xf32>) -> tensor<1x3x1x1xf32> - // CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%[[ADD]], %[[CST_2]], %[[CST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x1x1xf32>, tensor<3x2x1x1xf32>, tensor<3xf32>) -> tensor<1x2x1x3xf32> + // CHECK: %[[ADD:.*]] = tfl.add(%arg0, %[[CST]]) <{fused_activation_function = "RELU6"}> : (tensor<1x3x1x1xf32>, tensor<1xf32>) -> tensor<1x3x1x1xf32> + // CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%[[ADD]], %[[CST_2]], %[[CST_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x3x1x1xf32>, tensor<3x2x1x1xf32>, tensor<3xf32>) -> tensor<1x2x1x3xf32> // CHECK: return %[[CONV]] } @@ -2874,7 +2885,7 @@ func.func @replaceReshapeEqualWithOneHot(%arg: tensor<2x1xi32>) -> tensor<2x3xi1 // CHECK-DAG: %[[CST3:.*]] = arith.constant dense : tensor // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<2> : tensor<1xi32> // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST4]]) : (tensor<2x1xi32>, tensor<1xi32>) -> tensor<2xi32> - // CHECK: %[[RES:.*]] = "tfl.one_hot"(%[[RESHAPE]], %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xi1> + // CHECK: %[[RES:.*]] = "tfl.one_hot"(%[[RESHAPE]], %[[CST1]], %[[CST2]], %[[CST3]]) <{axis = -1 : i32}> : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xi1> } // CHECK-LABEL: ReplaceReshapeEqualWithOneHotWithBatchingDim @@ -2888,7 +2899,7 @@ func.func @ReplaceReshapeEqualWithOneHotWithBatchingDim(%arg: tensor<2x2x1xi32>) // CHECK-DAG: %[[CST3:.*]] = arith.constant dense : tensor // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<2> : tensor<2xi32> // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST4]]) : (tensor<2x2x1xi32>, tensor<2xi32>) -> tensor<2x2xi32> - // CHECK: %[[RES:.*]] = "tfl.one_hot"(%[[RESHAPE]], %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2x2xi32>, tensor, tensor, tensor) -> tensor<2x2x3xi1> + // CHECK: %[[RES:.*]] = "tfl.one_hot"(%[[RESHAPE]], %[[CST1]], %[[CST2]], %[[CST3]]) <{axis = -1 : i32}> : (tensor<2x2xi32>, tensor, tensor, tensor) -> tensor<2x2x3xi1> } // CHECK-LABEL: noReplaceReshapeEqualWithOneHotBadShape @@ -2937,7 +2948,7 @@ func.func @ReplaceReshapeEqualOneHotDynamicBatch(%arg0: tensor) -> (tenso // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<-1> : tensor // CHECK: %[[EXPAND_DIMS:.*]] = "tfl.expand_dims"(%arg0, %[[CST_3]]) : (tensor, tensor) -> tensor // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%0, %[[CST]]) : (tensor, tensor<1xi32>) -> tensor - // CHECK: %[[ONE_HOT:.*]] = "tfl.one_hot"(%1, %[[CST_0]], %[[CST_1]], %[[CST_2]]) {axis = -1 : i32} : (tensor, tensor, tensor, tensor) -> tensor + // CHECK: %[[ONE_HOT:.*]] = "tfl.one_hot"(%1, %[[CST_0]], %[[CST_1]], %[[CST_2]]) <{axis = -1 : i32}> : (tensor, tensor, tensor, tensor) -> tensor // CHECK-NEXT: return %[[ONE_HOT]] } @@ -2999,8 +3010,8 @@ func.func @fuseOneHotCast(%arg: 
tensor<2xi32>) -> (tensor<2x3xf32>, tensor<2x3xf // CHECK-DAG: %[[CST3:.*]] = arith.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<5.000000e+00> : tensor // CHECK-DAG: %[[CST5:.*]] = arith.constant dense<7.000000e+00> : tensor - // CHECK: %[[RES1:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> - // CHECK: %[[RES2:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST4]], %[[CST5]]) {axis = -1 : i32} : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> + // CHECK: %[[RES1:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) <{axis = -1 : i32}> : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> + // CHECK: %[[RES2:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST4]], %[[CST5]]) <{axis = -1 : i32}> : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> } // CHECK-LABEL: replaceOneHotFullyConnectedWithLookup @@ -3039,9 +3050,9 @@ func.func @dontReplaceOneHotFullyConnectedWithLookupBadIndexType(%arg: tensor<2x // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<1.000000e+00> : tensor // CHECK-DAG: %[[CST3:.*]] = arith.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<7.000000e+00> : tensor<5x3xf32> - // CHECK-DAG: %[[CST5:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi64>, tensor, tensor, tensor) -> tensor<2x3xf32> - // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x3xf32>, tensor<5x3xf32>, none) -> tensor<2x5xf32> + // CHECK-DAG: %[[CST5:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) <{axis = -1 : i32}> : (tensor<2xi64>, tensor, tensor, tensor) -> tensor<2x3xf32> + // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x3xf32>, tensor<5x3xf32>, none) -> tensor<2x5xf32> // CHECK: return %[[RES]] : tensor<2x5xf32> } @@ -3064,9 +3075,9 @@ func.func @dontReplaceOneHotFullyConnectedWithLookupBadIndexTypeWithOptionalAttr // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<1.000000e+00> : tensor // CHECK-DAG: %[[CST3:.*]] = arith.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<7.000000e+00> : tensor<5x3xf32> - // CHECK-DAG: %[[CST5:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi64>, tensor, tensor, tensor) -> tensor<2x3xf32> - // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) {asymmetric_quantize_inputs = true, + // CHECK-DAG: %[[CST5:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) <{axis = -1 : i32}> : (tensor<2xi64>, tensor, tensor, tensor) -> tensor<2x3xf32> + // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) <{asymmetric_quantize_inputs = true, } // CHECK-LABEL: ReplaceOneHotFullyConnectedWithLookup2DRank @@ -3109,9 +3120,9 @@ func.func @dontReplaceOneHotFullyConnectedWithLookupBadOn(%arg: tensor<2xi32>) - // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor // CHECK-DAG: 
%[[CST3:.*]] = arith.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<7.000000e+00> : tensor<5x3xf32> - // CHECK-DAG: %[[CST5:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> - // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x3xf32>, tensor<5x3xf32>, none) -> tensor<2x5xf32> + // CHECK-DAG: %[[CST5:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) <{axis = -1 : i32}> : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> + // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x3xf32>, tensor<5x3xf32>, none) -> tensor<2x5xf32> // CHECK: return %[[RES]] : tensor<2x5xf32> } @@ -3133,9 +3144,9 @@ func.func @dontReplaceOneHotFullyConnectedWithLookupBadOff(%arg: tensor<2xi32>) // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<1.000000e+00> : tensor // CHECK-DAG: %[[CST3:.*]] = arith.constant dense<-1.000000e+00> : tensor // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<7.000000e+00> : tensor<5x3xf32> - // CHECK-DAG: %[[CST5:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> - // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x3xf32>, tensor<5x3xf32>, none) -> tensor<2x5xf32> + // CHECK-DAG: %[[CST5:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) <{axis = -1 : i32}> : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> + // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x3xf32>, tensor<5x3xf32>, none) -> tensor<2x5xf32> // CHECK: return %[[RES]] : tensor<2x5xf32> } @@ -3158,8 +3169,8 @@ func.func @dontReplaceOneHotFullyConnectedWithLookupBadBias(%arg: tensor<2xi32>) // CHECK-DAG: %[[CST3:.*]] = arith.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<7.000000e+00> : tensor<5x3xf32> // CHECK-DAG: %[[CST5:.*]] = arith.constant dense<1.100000e+01> : tensor - // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> - // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x3xf32>, tensor<5x3xf32>, tensor) -> tensor<2x5xf32> + // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) <{axis = -1 : i32}> : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xf32> + // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x3xf32>, tensor<5x3xf32>, tensor) -> tensor<2x5xf32> // CHECK: return %[[RES]] : 
tensor<2x5xf32> } @@ -3291,10 +3302,10 @@ func.func @eliminateCumSumCheckIndices(%arg: tensor<1x2x1x3xf32>) -> (tensor<1x2 // CHECK-DAG: %[[AXIS_M1:.*]] = arith.constant dense<-1> : tensor // CHECK-DAG: %[[AXIS_P1:.*]] = arith.constant dense<1> : tensor // CHECK-DAG: %[[AXIS_P3:.*]] = arith.constant dense<3> : tensor - // CHECK: %[[RES_M3:.*]] = "tfl.cumsum"(%arg0, %[[AXIS_M3]]) {exclusive = false, reverse = false} : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> - // CHECK: %[[RES_M1:.*]] = "tfl.cumsum"(%arg0, %[[AXIS_M1]]) {exclusive = false, reverse = false} : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> - // CHECK: %[[RES_P1:.*]] = "tfl.cumsum"(%arg0, %[[AXIS_P1]]) {exclusive = false, reverse = false} : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> - // CHECK: %[[RES_P3:.*]] = "tfl.cumsum"(%arg0, %[[AXIS_P3]]) {exclusive = false, reverse = false} : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> + // CHECK: %[[RES_M3:.*]] = "tfl.cumsum"(%arg0, %[[AXIS_M3]]) <{exclusive = false, reverse = false}> : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> + // CHECK: %[[RES_M1:.*]] = "tfl.cumsum"(%arg0, %[[AXIS_M1]]) <{exclusive = false, reverse = false}> : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> + // CHECK: %[[RES_P1:.*]] = "tfl.cumsum"(%arg0, %[[AXIS_P1]]) <{exclusive = false, reverse = false}> : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> + // CHECK: %[[RES_P3:.*]] = "tfl.cumsum"(%arg0, %[[AXIS_P3]]) <{exclusive = false, reverse = false}> : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> // CHECK: return %arg0, %[[RES_M3]], %arg0, %[[RES_M1]], %arg0, %[[RES_P1]], %arg0, %[[RES_P3]] : tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32> } @@ -3308,8 +3319,8 @@ func.func @eliminateCumSumCheckAttributes(%arg: tensor<1x2x1x3xf32>) -> (tensor< func.return %res_ff, %res_ft, %res_tf, %res_tt: tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32> // CHECK: %[[AXIS:.*]] = arith.constant dense<2> : tensor - // CHECK: %[[RES_TF:.*]] = "tfl.cumsum"(%arg0, %[[AXIS]]) {exclusive = true, reverse = false} : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> - // CHECK: %[[RES_TT:.*]] = "tfl.cumsum"(%arg0, %[[AXIS]]) {exclusive = true, reverse = true} : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> + // CHECK: %[[RES_TF:.*]] = "tfl.cumsum"(%arg0, %[[AXIS]]) <{exclusive = true, reverse = false}> : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> + // CHECK: %[[RES_TT:.*]] = "tfl.cumsum"(%arg0, %[[AXIS]]) <{exclusive = true, reverse = true}> : (tensor<1x2x1x3xf32>, tensor) -> tensor<1x2x1x3xf32> // CHECK: return %arg0, %arg0, %[[RES_TF]], %[[RES_TT]] : tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32>, tensor<1x2x1x3xf32> } @@ -3325,7 +3336,7 @@ func.func @gelu(%arg0: tensor<3xf32>) -> tensor<3xf32> { func.return %4 : tensor<3xf32> // CHECK-LABEL:gelu -// CHECK: "tfl.gelu"(%arg0) {approximate = false} : (tensor<3xf32>) -> tensor<3xf32> +// CHECK: "tfl.gelu"(%arg0) <{approximate = false}> : (tensor<3xf32>) -> tensor<3xf32> } func.func @gelu_no_match(%arg0: tensor<3xf32>) -> tensor<3xf32> { @@ -3377,7 +3388,7 @@ func.func @gelu_approximate(%arg0: tensor<3xf32>) -> tensor<3xf32> { func.return %7 : tensor<3xf32> // CHECK-LABEL:gelu_approximate -// CHECK: "tfl.gelu"(%arg0) {approximate = true} : (tensor<3xf32>) -> tensor<3xf32> +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : 
(tensor<3xf32>) -> tensor<3xf32> } func.func @gelu_approximate1(%arg0: tensor<3xf32>) -> tensor<3xf32> { @@ -3397,7 +3408,7 @@ func.func @gelu_approximate1(%arg0: tensor<3xf32>) -> tensor<3xf32> { func.return %7 : tensor<3xf32> // CHECK-LABEL:gelu_approximate -// CHECK: "tfl.gelu"(%arg0) {approximate = true} : (tensor<3xf32>) -> tensor<3xf32> +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> } func.func @gelu_approximate_no_match(%arg0: tensor<3xf32>) -> tensor<3xf32> { @@ -3456,7 +3467,7 @@ func.func @eliminateExtraSelectLhs(%arg0: tensor<4x2x1xf32>, %arg1: tensor<4x2x1 // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2x2xf32> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<2.000000e+00> : tensor<2xf32> - // CHECK: %[[FC:.*]] = "tfl.fully_connected"(%arg0, %[[CST]], %[[CST_1]]) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x2x1xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2x1xf32> + // CHECK: %[[FC:.*]] = "tfl.fully_connected"(%arg0, %[[CST]], %[[CST_1]]) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<4x2x1xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2x1xf32> // CHECK-NEXT: %[[SELECT:.*]] = "tfl.select_v2" // CHECK-NEXT: return %[[SELECT]] } @@ -3475,7 +3486,7 @@ func.func @eliminateExtraSelectRhs(%arg0: tensor<4x2x1xf32>, %arg1: tensor<4x2x1 // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2x2xf32> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<2.000000e+00> : tensor<2xf32> - // CHECK: %[[FC:.*]] = "tfl.fully_connected"(%arg0, %[[CST]], %[[CST_1]]) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x2x1xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2x1xf32> + // CHECK: %[[FC:.*]] = "tfl.fully_connected"(%arg0, %[[CST]], %[[CST_1]]) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<4x2x1xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2x1xf32> // CHECK-NEXT: %[[SELECT:.*]] = "tfl.select_v2" // CHECK-NEXT: return %[[SELECT]] } @@ -3497,7 +3508,7 @@ func.func @DontEliminateExtraSelect(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi1 // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<2.000000e+00> : tensor<2xf32> // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : tensor<4x2xf32> // CHECK: %[[SELECT:.*]] = "tfl.select_v2"(%arg1, %arg0, %[[CST_2]]) : (tensor<4x2xi1>, tensor<4x2xf32>, tensor<4x2xf32>) -> tensor<4x2xf32> - // CHECK: %[[FC:.*]] = "tfl.fully_connected"(%[[SELECT]], %[[CST]], %[[CST_1]]) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + // CHECK: %[[FC:.*]] = "tfl.fully_connected"(%[[SELECT]], %[[CST]], %[[CST_1]]) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> // CHECK-NEXT: %[[SELECT_1:.*]] = "tfl.select_v2" // CHECK-NEXT: return %[[SELECT_1]] } @@ -3549,8 +3560,8 @@ func.func @fuseReluToMin1_StaticShapeWithSameShapeCst_Float2(%arg0: tensor<2x2xf // CHECK-LABEL: func @fuseAddAndStridedSlice func.func @fuseAddAndStridedSlice(%arg0: 
tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<4xi32> { // CHECK-DAG: %[[cst:.*]] = arith.constant dense<1> : tensor<1xi32> - // CHECK-DAG: %[[c0:.*]] = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %1 = "tfl.strided_slice"(%arg0, %arg1, %[[cst]], %[[c0]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = true, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK-DAG: %[[c0:.*]] = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> + // CHECK: %1 = "tfl.strided_slice"(%arg0, %arg1, %[[cst]], %[[c0]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = true, shrink_axis_mask = 0 : i32}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %cst_1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> @@ -3562,8 +3573,8 @@ func.func @fuseAddAndStridedSlice(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> // CHECK-LABEL: func @fuseSubAndStridedSlice func.func @fuseSubAndStridedSlice(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<4xi32> { // CHECK-DAG: %[[cst:.*]] = arith.constant dense<1> : tensor<1xi32> - // CHECK-DAG: %[[c0:.*]] = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %1 = "tfl.strided_slice"(%arg0, %arg1, %[[cst]], %[[c0]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = true, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK-DAG: %[[c0:.*]] = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> + // CHECK: %1 = "tfl.strided_slice"(%arg0, %arg1, %[[cst]], %[[c0]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = true, shrink_axis_mask = 0 : i32}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %cst_1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> @@ -3574,9 +3585,9 @@ func.func @fuseSubAndStridedSlice(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> // CHECK-LABEL: func @dontFuseAddAndStridedSliceNonConstantStride func.func @dontFuseAddAndStridedSliceNonConstantStrides(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>) -> tensor<4xi32> { - // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %1 = tfl.add(%arg1, %0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> - // CHECK: %2 = "tfl.strided_slice"(%arg0, %arg1, %1, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK-DAG: %0 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor + // CHECK: %1 = tfl.add(%arg1, %0) <{fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %2 = "tfl.strided_slice"(%arg0, %arg1, %1, %arg2) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<4xi32>, 
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> %cst = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %0 = "tfl.add"(%arg1, %cst) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> @@ -3586,9 +3597,9 @@ func.func @dontFuseAddAndStridedSliceNonConstantStrides(%arg0: tensor<4xi32>, %a // CHECK-LABEL: func @dontFuseAddAndStridedSliceOffset func.func @dontFuseAddAndStridedSliceOffset(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<4xi32> { - // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %1 = tfl.add(%arg2, %0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> - // CHECK: %2 = "tfl.strided_slice"(%arg0, %arg1, %1, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK-DAG: %0 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor + // CHECK: %1 = tfl.add(%arg2, %0) <{fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %2 = "tfl.strided_slice"(%arg0, %arg1, %1, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> %cst = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %0 = "tfl.add"(%arg2, %cst) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> @@ -3599,7 +3610,7 @@ func.func @dontFuseAddAndStridedSliceOffset(%arg0: tensor<4xi32>, %arg1: tensor< // CHECK-LABEL: func @dontFuseAddAndStridedSliceNonConstantOffset func.func @dontFuseAddAndStridedSliceNonConstantOffset(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>) -> tensor<4xi32> { // CHECK: %0 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32> - // CHECK: "tfl.strided_slice"(%arg0, %arg1, %0, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK: "tfl.strided_slice"(%arg0, %arg1, %0, %arg2) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> %0 = "tfl.add"(%arg1, %arg1) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> %1 = "tfl.strided_slice"(%arg0, %arg1, %0, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> @@ -3608,10 +3619,10 @@ func.func @dontFuseAddAndStridedSliceNonConstantOffset(%arg0: tensor<4xi32>, %ar // CHECK-LABEL: func @dontFuseAddAndStridedSliceBeginMask func.func @dontFuseAddAndStridedSliceBeginMask(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<4xi32> { - // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor - // CHECK-DAG: %1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %2 = tfl.add(%arg1, %0) {fused_activation_function = 
"NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> - // CHECK: %3 = "tfl.strided_slice"(%arg0, %arg1, %2, %1) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK-DAG: %0 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor + // CHECK-DAG: %1 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> + // CHECK: %2 = tfl.add(%arg1, %0) <{fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %3 = "tfl.strided_slice"(%arg0, %arg1, %2, %1) <{begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %cst_1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> @@ -3622,10 +3633,10 @@ func.func @dontFuseAddAndStridedSliceBeginMask(%arg0: tensor<4xi32>, %arg1: tens // CHECK-LABEL: func @dontFuseAddAndStridedSliceEndMask func.func @dontFuseAddAndStridedSliceEndMask(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<4xi32> { - // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor - // CHECK-DAG: %1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %2 = tfl.add(%arg1, %0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> - // CHECK: %3 = "tfl.strided_slice"(%arg0, %arg1, %2, %1) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 1 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK-DAG: %0 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor + // CHECK-DAG: %1 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> + // CHECK: %2 = tfl.add(%arg1, %0) <{fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %3 = "tfl.strided_slice"(%arg0, %arg1, %2, %1) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 1 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %cst_1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> @@ -3636,10 +3647,10 @@ func.func @dontFuseAddAndStridedSliceEndMask(%arg0: tensor<4xi32>, %arg1: tensor // CHECK-LABEL: func @dontFuseAddAndStridedSliceEllipsisMask func.func @dontFuseAddAndStridedSliceEllipsisMask(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<4xi32> { - // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor - // CHECK-DAG: %1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %2 = tfl.add(%arg1, %0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> - // CHECK: %3 = "tfl.strided_slice"(%arg0, %arg1, %2, %1) {begin_mask = 0 : i32, ellipsis_mask = 1 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK-DAG: %0 = 
"tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor + // CHECK-DAG: %1 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> + // CHECK: %2 = tfl.add(%arg1, %0) <{fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %3 = "tfl.strided_slice"(%arg0, %arg1, %2, %1) <{begin_mask = 0 : i32, ellipsis_mask = 1 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %cst_1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> @@ -3752,7 +3763,7 @@ func.func @FuseReshapeAndTransposeAroundBatchMatmul(%arg0: tensor<1x128x1024xf32 %cst_3 = arith.constant dense<[2, 0, 1]> : tensor<3xi32> %0 = "tfl.transpose"(%arg0, %cst_3) : (tensor<1x128x1024xf32>, tensor<3xi32>) -> tensor<1024x1x128xf32> %1 = "tfl.reshape"(%0, %cst_2) : (tensor<1024x1x128xf32>, tensor<2xi32>) -> tensor<1024x128xf32> - // CHECK: %[[BMM:.*]] = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} + // CHECK: %[[BMM:.*]] = "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> // CHECK-NOT: tfl.reshape // CHECK-NOT: tfl.transpose %2 = "tfl.batch_matmul"(%arg1, %1) {adj_x = true, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<1024x16xf32>, tensor<1024x128xf32>) -> tensor<16x128xf32> @@ -3771,7 +3782,7 @@ func.func @FuseReshapeAndTransposeAroundBatchMatmulWithLargerThan3Rank(%arg0: te %0 = "tfl.transpose"(%arg0, %cst_3) : (tensor<1x128x4x256xf32>, tensor<4xi32>) -> tensor<4x256x1x128xf32> %1 = "tfl.reshape"(%0, %cst_2) : (tensor<4x256x1x128xf32>, tensor<2xi32>) -> tensor<1024x128xf32> // CHECK: %[[RESHAE_ARG0:.*]] = "tfl.reshape"(%arg0, %[[CST:.*]]) : (tensor<1x128x4x256xf32>, tensor<3xi32>) -> tensor<1x128x1024xf32> - // CHECK: %[[BMM:.*]] = "tfl.batch_matmul"(%[[RESHAE_ARG0]], %arg1) {adj_x = false, adj_y = true, asymmetric_quantize_inputs = false} + // CHECK: %[[BMM:.*]] = "tfl.batch_matmul"(%[[RESHAE_ARG0]], %arg1) <{adj_x = false, adj_y = true, asymmetric_quantize_inputs = false}> // CHECK-NOT: tfl.reshape // CHECK-NOT: tfl.transpose %2 = "tfl.batch_matmul"(%arg1, %1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<16x1024xf32>, tensor<1024x128xf32>) -> tensor<16x128xf32> @@ -3812,7 +3823,7 @@ func.func @FuseTransposeReshapeIntoBatchMatmul(%arg0: tensor<4x1024xf32>, %arg1: %0 = "tfl.transpose"(%arg1, %cst_1) : (tensor<8x4x256xf32>, tensor<3xi32>) -> tensor<4x256x8xf32> %1 = "tfl.reshape"(%0, %cst_0) : (tensor<4x256x8xf32>, tensor<2xi32>) -> tensor<1024x8xf32> // CHECK: %[[RES0:.*]] = "tfl.reshape"(%arg1, %[[CST:.*]]) : (tensor<8x4x256xf32>, tensor<2xi32>) -> tensor<8x1024xf32> - // CHECK: %[[RES1:.*]] = "tfl.batch_matmul"(%arg0, %[[RES0]]) {adj_x = false, adj_y = true, asymmetric_quantize_inputs = false} : (tensor<4x1024xf32>, tensor<8x1024xf32>) -> tensor<4x8xf32> + // CHECK: %[[RES1:.*]] = "tfl.batch_matmul"(%arg0, %[[RES0]]) <{adj_x = false, adj_y = true, asymmetric_quantize_inputs = false}> : (tensor<4x1024xf32>, tensor<8x1024xf32>) -> tensor<4x8xf32> %2 = "tfl.batch_matmul"(%arg0, %1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<4x1024xf32>, tensor<1024x8xf32>) -> tensor<4x8xf32> func.return %2 : tensor<4x8xf32> // CHECK: return %[[RES1]] : tensor<4x8xf32> @@ -3821,7 
+3832,7 @@ func.func @FuseTransposeReshapeIntoBatchMatmul(%arg0: tensor<4x1024xf32>, %arg1: // CHECK-LABEL: FuseTransposeAfterBatchMatmul func.func @FuseTransposeAfterBatchMatmul(%arg0: tensor<4x1024xf32>, %arg1: tensor<8x1024xf32>, %arg2: none) -> tensor<8x4xf32> { %cst = arith.constant dense<[1, 0]> : tensor<2xi32> - // CHECK: %[[RES0:.*]] = "tfl.batch_matmul"(%arg1, %arg0) {adj_x = false, adj_y = true, asymmetric_quantize_inputs = false} : (tensor<8x1024xf32>, tensor<4x1024xf32>) -> tensor<8x4xf32> + // CHECK: %[[RES0:.*]] = "tfl.batch_matmul"(%arg1, %arg0) <{adj_x = false, adj_y = true, asymmetric_quantize_inputs = false}> : (tensor<8x1024xf32>, tensor<4x1024xf32>) -> tensor<8x4xf32> %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true, asymmetric_quantize_inputs = false} : (tensor<4x1024xf32>, tensor<8x1024xf32>) -> tensor<4x8xf32> %1 = "tfl.transpose"(%0, %cst) : (tensor<4x8xf32>, tensor<2xi32>) -> tensor<8x4xf32> func.return %1 : tensor<8x4xf32> @@ -3980,7 +3991,7 @@ func.func @FuseExcessBroadcastingOnReshapes(%arg0: tensor<1x8xf32>) -> tensor<1x // CHECK: %cst_0 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> // CHECK: %cst_1 = arith.constant dense<[8, 1]> : tensor<2xi32> // CHECK: %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<1x8xf32>, tensor<2xi32>) -> tensor<8x1xf32> - // CHECK: %1 = tfl.mul(%0, %cst) {fused_activation_function = "NONE"} : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %1 = tfl.mul(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<8x16xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> // CHECK: return %2 : tensor<1x1x1x128xf32> } @@ -4003,7 +4014,7 @@ func.func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> // CHECK: return %0 : tensor<3x3xf32> } @@ -4012,7 +4023,7 @@ func.func @broadcast_to_i32_low_dim(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> - // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> // CHECK: return %0 : tensor<3x3xi32> } @@ -4021,7 +4032,7 @@ func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> // CHECK: return %0 : tensor<3x3xf32> } @@ -4030,7 +4041,7 @@ func.func @broadcast_to_i16_low_dim(%arg0: 
tensor<3xi16>, %arg1: tensor<2xi32>) %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> return %0 : tensor<3x3xi16> // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi16> - // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xi16>, tensor<3x3xi16>) -> tensor<3x3xi16> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi16>, tensor<3x3xi16>) -> tensor<3x3xi16> // CHECK: return %0 : tensor<3x3xi16> } @@ -4040,7 +4051,7 @@ func.func @broadcast_to_i32_low_dim_with_unknown_output(%arg0: tensor<3xi32>, %a return %0 : tensor<*xi32> // CHECK: %cst = arith.constant dense<1> : tensor // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<*xi32> - // CHECK: %1 = tfl.mul(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> + // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> // CHECK: return %1 : tensor<*xi32> } @@ -4049,7 +4060,7 @@ func.func @broadcast_to_ui32(%arg0: tensor, %arg1: tensor<1xi64>) -> tenso %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor, tensor<1xi64>) -> tensor<10xui32> return %0 : tensor<10xui32> // CHECK: %cst = arith.constant dense<1> : tensor<10xui32> - // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor, tensor<10xui32>) -> tensor<10xui32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor, tensor<10xui32>) -> tensor<10xui32> // CHECK: return %0 : tensor<10xui32> } @@ -4058,7 +4069,7 @@ func.func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tenso %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> // CHECK: return %0 : tensor<3x3xf32> } @@ -4067,7 +4078,7 @@ func.func @broadcast_to_i32(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tenso %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> - // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> // CHECK: return %0 : tensor<3x3xi32> } @@ -4077,7 +4088,7 @@ func.func @broadcast_to_i32_with_dynamic_shape_and_output(%arg0: tensor<3xi32>, return %0 : tensor<3x?xi32> // CHECK: %cst = arith.constant dense<1> : tensor // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<3x?xi32> - // CHECK: %1 = tfl.mul(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x?xi32>) -> tensor<3x?xi32> + // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x?xi32>) -> tensor<3x?xi32> // CHECK: return %1 : tensor<3x?xi32> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir index c838ead5f031f9..9e2cccaf5158d5 100644 --- 
a/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir @@ -16,9 +16,9 @@ func.func @FuseTransposeFCLhsToBatchMatmul(%arg0: tensor<1024x4xf32>, %arg1: ten %cst_0 = arith.constant dense<[1, 0]> : tensor<2xi32> %cst_1 = "tfl.no_value"() {value} : () -> none %0 = "tfl.transpose"(%arg0, %cst_0) : (tensor<1024x4xf32>, tensor<2xi32>) -> tensor<4x1024xf32> - // CHECK: %[[RES0:.*]] = "tfl.batch_matmul"(%arg1, %arg0) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<8x1024xf32>, tensor<1024x4xf32>) -> tensor<8x4xf32> + // CHECK: %[[RES0:.*]] = "tfl.batch_matmul"(%arg1, %arg0) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<8x1024xf32>, tensor<1024x4xf32>) -> tensor<8x4xf32> %1 = "tfl.fully_connected"(%0, %arg1, %cst_1) {asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x1024xf32>, tensor<8x1024xf32>, none) -> tensor<4x8xf32> - // CHECK: %[[RES1:.*]] = "tfl.batch_matmul"(%[[RES0]], %arg2) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<8x4xf32>, tensor<4x256xf32>) -> tensor<8x256xf32> + // CHECK: %[[RES1:.*]] = "tfl.batch_matmul"(%[[RES0]], %arg2) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<8x4xf32>, tensor<4x256xf32>) -> tensor<8x256xf32> %2 = "tfl.batch_matmul"(%1, %arg2) {adj_x = true, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<4x8xf32>, tensor<4x256xf32>) -> tensor<8x256xf32> func.return %2 : tensor<8x256xf32> // CHECK: return %[[RES1]] : tensor<8x256xf32> @@ -34,7 +34,7 @@ func.func @Batchmatmul2Fullyconnected(%arg0: tensor<4x128x2xf32>) -> (tensor<4x1 // CHECK-SAME: [1.000000e+00, 2.000000e+00] // CHECK-SAME: tensor<1x2xf32> // CHECK: %[[FC_RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONST_WEIGHT]] - // CHECK-SAME: {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> + // CHECK-SAME: <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> // CHECK-NEXT: return %[[FC_RES]] } @@ -48,7 +48,7 @@ func.func @Batchmatmul2FullyconnectedAdjy(%arg0: tensor<4x128x2xf32>) -> (tensor // CHECK-SAME: [1.000000e+00, 2.000000e+00] // CHECK-SAME: tensor<1x2xf32> // CHECK: %[[FC_RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONST_WEIGHT]] - // CHECK-SAME: {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> + // CHECK-SAME: <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> // CHECK-NEXT: return %[[FC_RES]] } @@ -62,7 +62,7 @@ func.func @Batchmatmul2FullyconnectedAdjx(%arg0: tensor<4x2x128xf32>) -> (tensor // CHECK: %[[TRANSPOSED_X:.*]] = "tfl.transpose" // CHECK-SAME: (tensor<4x2x128xf32>, tensor<3xi32>) -> tensor<4x128x2xf32> // CHECK-NEXT: %[[FC_RES:.*]] = "tfl.fully_connected"(%[[TRANSPOSED_X]] - // CHECK-SAME: {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> + // CHECK-SAME: <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<4x128x2xf32>, 
tensor<1x2xf32>, none) -> tensor<4x128x1xf32> // CHECK-NEXT: return %[[FC_RES]] } @@ -87,7 +87,7 @@ func.func @Batchmatmul2FullyconnectedTransposedY(%arg0: tensor<4x128x2xf32>) -> // CHECK-SAME: [1.000000e+00, 2.000000e+00] // CHECK-SAME: tensor<1x2xf32> // CHECK: %[[FC_RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONST_WEIGHT]] - // CHECK-SAME: {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> + // CHECK-SAME: <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> // CHECK-NEXT: return %[[FC_RES]] } @@ -113,7 +113,7 @@ func.func @Batchmatmul2FullyconnectedQDQ(%arg0: tensor<4x128x2xf32>, %arg1: tens // CHECK: %[[TRANSPOSED_X:.*]] = "tfl.transpose" // CHECK-SAME: (tensor<2x1xf32>, tensor<2xi32>) -> tensor<1x2xf32> // CHECK: %[[FC_RES:.*]] = "tfl.fully_connected"(%arg0, %[[TRANSPOSED_X]] - // CHECK-SAME: {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> + // CHECK-SAME: <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> // CHECK-NEXT: return %[[FC_RES]] } @@ -123,8 +123,8 @@ func.func @BatchmatmulToReduceSumI32(%arg0: tensor<1x16384x257xi32>) -> (tensor< %0 = arith.constant dense<1> : tensor<1x1x16384xi32> %1 = "tfl.batch_matmul"(%0, %arg0) {adj_x = false, adj_y = false} : (tensor<1x1x16384xi32>, tensor<1x16384x257xi32>) -> tensor<1x1x257xi32> func.return %1 : tensor<1x1x257xi32> - // CHECK: %[[CONST_DIM:.*]] = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[RED:.*]] = "tfl.sum"(%arg0, %[[CONST_DIM]]) {keep_dims = true} : (tensor<1x16384x257xi32>, tensor<1xi32>) -> tensor<1x1x257xi32> + // CHECK: %[[CONST_DIM:.*]] = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> + // CHECK: %[[RED:.*]] = "tfl.sum"(%arg0, %[[CONST_DIM]]) <{keep_dims = true}> : (tensor<1x16384x257xi32>, tensor<1xi32>) -> tensor<1x1x257xi32> } // CHECK-LABEL: BatchmatmulToReduceSumF32 @@ -133,6 +133,6 @@ func.func @BatchmatmulToReduceSumF32(%arg0: tensor<1x16384x257xf32>) -> (tensor< %0 = arith.constant dense<1.0> : tensor<1x1x16384xf32> %1 = "tfl.batch_matmul"(%0, %arg0) {adj_x = false, adj_y = false} : (tensor<1x1x16384xf32>, tensor<1x16384x257xf32>) -> tensor<1x1x257xf32> func.return %1 : tensor<1x1x257xf32> - // CHECK: %[[CONST_DIM:.*]] = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[RED:.*]] = "tfl.sum"(%arg0, %[[CONST_DIM]]) {keep_dims = true} : (tensor<1x16384x257xf32>, tensor<1xi32>) -> tensor<1x1x257xf32> + // CHECK: %[[CONST_DIM:.*]] = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> + // CHECK: %[[RED:.*]] = "tfl.sum"(%arg0, %[[CONST_DIM]]) <{keep_dims = true}> : (tensor<1x16384x257xf32>, tensor<1xi32>) -> tensor<1x1x257xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_no_verify.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_no_verify.mlir index caa4af5efbbc64..d9c81887664a0f 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize_no_verify.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize_no_verify.mlir @@ -26,8 +26,8 @@ func.func @fuseBroadcastMulIntoFullyConnected(%arg0: tensor<1x10368xbf16>) -> te %1 = "tfl.mul"(%0, 
%cst_2) {fused_activation_function = "NONE"} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16> func.return %1 : tensor<32x1x256xbf16> -// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{{.*}}} : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16> -// CHECK: %[[V1:.*]] = tfl.mul(%[[V0]], {{.*}}) {{{.*}}} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16> +// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) <{{{.*}}}> : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16> +// CHECK: %[[V1:.*]] = tfl.mul(%[[V0]], {{.*}}) <{{{.*}}}> : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16> // CHECK: return %[[V1]] : tensor<32x1x256xbf16> } diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize-dynamic-range.mlir index c3cc6aa588a2df..c68fab4762e120 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize-dynamic-range.mlir @@ -12,9 +12,9 @@ func.func @PruneUnusedCustomOp(%arg0: tensor<1x1x1x1xf32>) -> tensor<*xf32> attr %custom_3 = "tfl.custom"(%arg0, %dq_w) {custom_code = "CustomTestOp", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<1024x1x1x1xf32>) -> tensor<*xf32> func.return %custom_3 : tensor<*xf32> -// CHECK: %[[q_w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>>, value = dense<127> : tensor<1024x1x1x1xi8>} : () -> tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// CHECK: %[[q_w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>>, value = dense<127> : tensor<1024x1x1x1xi8>}> : () -> tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> // CHECK: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w:.*]]) : (tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1024x1x1x1xf32> -// CHECK: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<1024x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<1024x1x1x1xf32>) -> tensor<*xf32> // CHECK: return %[[custom_3:.*]] } @@ -27,11 +27,11 @@ func.func @NotPruneUnusedCustomOp(%arg0: tensor<1x1x1x1xf32>) -> tensor<*xf32> a %custom_3 = "tfl.custom"(%arg0, %dq_w) {custom_code = "CustomTestOp2", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<1024x1x1x1xf32>) -> tensor<*xf32> func.return %custom_3 : tensor<*xf32> -// CHECK: %[[q_w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>>, value = dense<127> : tensor<1024x1x1x1xi8>} : () -> tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// CHECK: %[[q_w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>>, value = dense<127> : tensor<1024x1x1x1xi8>}> : () -> tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> // CHECK: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w:.*]]) : (tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1024x1x1x1xf32> -// CHECK: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) {custom_code = "CustomTestOp2", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<1024x1x1x1xf32>) -> tensor<*xf32> -// CHECK: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) {custom_code = "CustomTestOp2", custom_option = #tfl} : 
(tensor<1x1x1x1xf32>, tensor<1024x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) <{custom_code = "CustomTestOp2", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<1024x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) <{custom_code = "CustomTestOp2", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<1024x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) <{custom_code = "CustomTestOp2", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<1024x1x1x1xf32>) -> tensor<*xf32> // CHECK: return %[[custom_3:.*]] } @@ -46,22 +46,22 @@ func.func @PruneQuantizedCustomOp(%arg0: tensor<1x1x1x1xf32>) -> tensor<*xf32> a func.return %custom : tensor<*xf32> // CHECK: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1024x1x1x1xf32> -// CHECK: %[[custom:.*]] = "tfl.custom"(%arg0, %[[w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} +// CHECK: %[[custom:.*]] = "tfl.custom"(%arg0, %[[w:.*]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> // CHECK: return %[[custom:.*]] -// NotPrune: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// NotPrune: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> // NotPrune: %[[dq_w:.*]] = "tfl.dequantize"(%[[w:.*]]) : (tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1024x1x1x1xf32> -// NotPrune: %[[custom:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} -// NotPrune: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} -// NotPrune: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} +// NotPrune: %[[custom:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> +// NotPrune: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[w:.*]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> +// NotPrune: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w:.*]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> -// NoSideEffect: %[[q_w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> -// NoSideEffect: %[[custom:.*]] = "tfl.custom"(%arg0, %[[q_w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} +// NoSideEffect: %[[q_w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// NoSideEffect: %[[custom:.*]] = "tfl.custom"(%arg0, %[[q_w:.*]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> // NoSideEffect: return %[[custom:.*]] -// NoSideEffectWeightOnly: %[[q_w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// NoSideEffectWeightOnly: %[[q_w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> // NoSideEffectWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w:.*]]) : (tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1024x1x1x1xf32> -// NoSideEffectWeightOnly: %[[custom:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} +// NoSideEffectWeightOnly: %[[custom:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) <{custom_code = "CustomTestOp",
custom_option = #tfl}> // NoSideEffectWeightOnly: return %[[custom:.*]] } @@ -80,16 +80,16 @@ func.func @QuantizeCustomOp(%arg0: tensor<1x1x1x1xf32>) -> (tensor<*xf32>, tenso // CHECK: %[[w_1:.*]] = arith.constant dense<1.270000e+02> : tensor<4096x1x1x1xf32> // CHECK: %[[w_2:.*]] = arith.constant dense<1.270000e+02> : tensor<128x1x1x1xf32> // CHECK: %[[b:.*]] = arith.constant dense<1.270000e+02> : tensor<2048x1x1x1xf32> -// CHECK: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> -// CHECK: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp2", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> -// CHECK: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp3", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp2", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp3", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> // CHECK: return %[[custom_1:.*]], %[[custom_2:.*]], %[[custom_3:.*]] // CustomOpWeightOnly: %[[w_1:.*]] = arith.constant dense<1.270000e+02> : tensor<4096x1x1x1xf32> -// CustomOpWeightOnly: %[[q_w1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// CustomOpWeightOnly: %[[q_w1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>> // CustomOpWeightOnly: %[[dq_w1:.*]] = "tfl.dequantize"(%[[q_w1]]) : (tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> tensor<4096x1x1x1xf32> // CustomOpWeightOnly: %[[w_2:.*]] = arith.constant dense<1.270000e+02> : tensor<128x1x1x1xf32> -// CustomOpWeightOnly: %[[q_b:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2048x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// CustomOpWeightOnly: %[[q_b:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<2048x1x1x1x!quant.uniform:f32, 1.000000e+00>> // CustomOpWeightOnly: %[[dq_b:.*]] = "tfl.dequantize"(%[[q_b]]) : (tensor<2048x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> tensor<2048x1x1x1xf32> // CustomOpWeightOnly: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[dq_w1]], %[[w_2]], %[[dq_b]]) {custom_code = "CustomTestOp", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> // CustomOpWeightOnly: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp2", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir 
b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir index de372f38daf906..9e34c1bd7bdbac 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir @@ -20,8 +20,8 @@ func.func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor) -> (tensor<2xf // CHECK-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %arg0) // CHECK-NEXT: return %[[split]]#0, %[[split]]#1 -// QDQ-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<4x!quant.uniform>} : (tensor<4xf32>) -> tensor<4x!quant.uniform> -// QDQ-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %[[q]]) {num_splits = 4 : i32} : (tensor, tensor<4x!quant.uniform>) -> (tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) +// QDQ-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<4x!quant.uniform>}> : (tensor<4xf32>) -> tensor<4x!quant.uniform> +// QDQ-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %[[q]]) <{num_splits = 4 : i32}> : (tensor, tensor<4x!quant.uniform>) -> (tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) // QDQ-NEXT: %[[out1:.*]] = "tfl.dequantize"(%[[split]]#0) : (tensor<2x!quant.uniform>) -> tensor<2xf32> // QDQ-NEXT: %[[out2:.*]] = "tfl.dequantize"(%[[split]]#1) : (tensor<2x!quant.uniform>) -> tensor<2xf32> // QDQ-NEXT: return %[[out1]], %[[out2]] : tensor<2xf32>, tensor<2xf32> @@ -37,8 +37,8 @@ func.func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform>, // CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform> // CHECK-NEXT: return %[[fc]] -// QDQ-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform>, tensor<128x512x!quant.uniform:f32, 1.000000e+00>>, none) -> tensor<384x128x!quant.uniform> -// QDQ-NEXT: %[[q:.*]] = "tfl.quantize"(%[[fc]]) {qtype = tensor<384x128x!quant.uniform>} : (tensor<384x128x!quant.uniform>) -> tensor<384x128x!quant.uniform> +// QDQ-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<384x512x!quant.uniform>, tensor<128x512x!quant.uniform:f32, 1.000000e+00>>, none) -> tensor<384x128x!quant.uniform> +// QDQ-NEXT: %[[q:.*]] = "tfl.quantize"(%[[fc]]) <{qtype = tensor<384x128x!quant.uniform>}> : (tensor<384x128x!quant.uniform>) -> tensor<384x128x!quant.uniform> // QDQ-NEXT: return %[[q]] : tensor<384x128x!quant.uniform> } @@ -64,11 +64,11 @@ func.func @main2(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf // CHECK: func @main(%arg0: tensor<1x224x224x3x!quant.uniform>) // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<[1, 401408]> : tensor<2xi32> -// CHECK-NEXT: %[[q_cst_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} -// CHECK-NEXT: %[[q_cst_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} -// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[q_cst_0]], %[[q_cst_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} +// CHECK-NEXT: %[[q_cst_0:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}> +// CHECK-NEXT: 
%[[q_cst_1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>}> +// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[q_cst_0]], %[[q_cst_1]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> // CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[cst]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) +// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) <{beta = 1.000000e+00 : f32}> : (tensor<1x401408x!quant.uniform>) // CHECK-NEXT: return %[[softmax]] : tensor<1x401408x!quant.uniform> // CHECK-NEXT:} @@ -81,7 +81,7 @@ func.func @main2(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf func.func @HandleReturnedDequantizeWithAnotherUse(%arg0: tensor<128x16xf32>) -> (tensor<128x16xf32>, tensor<128xi32>) { // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1> : tensor %cst = arith.constant dense<1> : tensor -// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<128x16xf32>) -> tensor<128x16xf32> +// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : f32}> : (tensor<128x16xf32>) -> tensor<128x16xf32> %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<128x16xf32>) -> tensor<128x16xf32> %1 = "tfl.quantize"(%0) {qtype = tensor<128x16x!quant.uniform>, volatile} : (tensor<128x16xf32>) -> tensor<128x16x!quant.uniform> %2 = "tfl.dequantize"(%1) : (tensor<128x16x!quant.uniform>) -> tensor<128x16xf32> @@ -145,11 +145,11 @@ func.func @RemoveLeadingQdq(%arg0: tensor<4xf32>, %arg1: tensor) -> (tensor func.return %4 : tensor<2xf32> // CHECK-NEXT: %[[dequant:.*]] = "tfl.dequantize"(%arg0) : (tensor<4x!quant.uniform>) -> tensor<4xf32> -// CHECK-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %[[dequant]]) {num_splits = 4 : i32} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -// CHECK-NEXT: %[[quant:.*]] = "tfl.quantize"(%[[split]]#0) {qtype = tensor<2x!quant.uniform>, volatile} : (tensor<2xf32>) -> tensor<2x!quant.uniform> +// CHECK-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %[[dequant]]) <{num_splits = 4 : i32}> : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) +// CHECK-NEXT: %[[quant:.*]] = "tfl.quantize"(%[[split]]#0) <{qtype = tensor<2x!quant.uniform>}> {volatile} : (tensor<2xf32>) -> tensor<2x!quant.uniform> // CHECK-NEXT: return %[[quant]] : tensor<2x!quant.uniform> -// QDQ-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %arg0) {num_splits = 4 : i32} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) +// QDQ-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %arg0) <{num_splits = 4 : i32}> : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) // QDQ-NEXT: return %[[split]]#0 : tensor<2xf32> } @@ -166,7 +166,7 @@ func.func @FoldTranspose(%arg0: tensor<1x10x20x3xf32>) -> tensor<1x20x40x16xf32> return %5 : tensor<1x20x40x16xf32> // CHECK-NOT: "tfl.transpose" - // CHECK: "tfl.pseudo_qconst"() {qtype = tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>>, value = 
dense<"0x03030402FD010302010103FE0301020001010001FD02030101FE0400020100FDFEFD01FC01FF02FEFCFE000303FCFE00FF0301FF04010303FF0402FE01FF01000002FD03FD03FC020202FE0204FD03FF01FFFD03FEFE010003FFFF010103FD00FCFEFE020300FFFE02FD03010402040201010401FCFDFDFF0102FE010003FD00FD02FF03FF000201FF00FD0204FD010102FFFF02020003000102FF0002FF0204040300FEFFFEFDFCFC000000000201020000010001FF00FFFF01FF03FE0003FF03FFFEFE03FE03FF0000FE0303FE0002FF01FF01FF04FDFD01FD020101FDFE0101030303020203030301FD010104FD000103FC03FF02FE020402000002FDFF0103FF03010102FDFE02FF00FE01FD02FEFE0002FD02FE0203FFFFFC01FC0102FE04FCFEFC00FCFCFF03000301FFFE03030100030001000302FC01FD0000FD010101FC01020201FDFFFE02FE00FE0201020003040203010100010404FE00FDFE04FE0401FEFDFDFD00FD04FEFCFF03FFFDFF01FF04030403020200020303FF00FF03FD000104FEFD04FCFCFDFE02FF02000003FF00FF030002FDFEFD030300030401000104FCFE030103FC01FD00FC03FE"> : tensor<16x3x3x3xi8>} : () -> tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>> + // CHECK: "tfl.pseudo_qconst"() <{qtype = tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>>, value = dense<"0x03030402FD010302010103FE0301020001010001FD02030101FE0400020100FDFEFD01FC01FF02FEFCFE000303FCFE00FF0301FF04010303FF0402FE01FF01000002FD03FD03FC020202FE0204FD03FF01FFFD03FEFE010003FFFF010103FD00FCFEFE020300FFFE02FD03010402040201010401FCFDFDFF0102FE010003FD00FD02FF03FF000201FF00FD0204FD010102FFFF02020003000102FF0002FF0204040300FEFFFEFDFCFC000000000201020000010001FF00FFFF01FF03FE0003FF03FFFEFE03FE03FF0000FE0303FE0002FF01FF01FF04FDFD01FD020101FDFE0101030303020203030301FD010104FD000103FC03FF02FE020402000002FDFF0103FF03010102FDFE02FF00FE01FD02FEFE0002FD02FE0203FFFFFC01FC0102FE04FCFEFC00FCFCFF03000301FFFE03030100030001000302FC01FD0000FD010101FC01020201FDFFFE02FE00FE0201020003040203010100010404FE00FDFE04FE0401FEFDFDFD00FD04FEFCFF03FFFDFF01FF04030403020200020303FF00FF03FD000104FEFD04FCFCFDFE02FF02000003FF00FF030002FDFEFD030300030401000104FCFE030103FC01FD00FC03FE"> : tensor<16x3x3x3xi8>}> : () -> tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>> // CHECK-NEXT: "tfl.transpose_conv" } @@ -178,6 +178,6 @@ func.func @FoldReshape(%arg0: tensor<4xi32>, %arg1: tensor<1x48x80x16x!quant.uni %2 = "tfl.transpose_conv"(%arg0, %1, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>>, tensor<1x48x80x16x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x96x160x1x!quant.uniform> return %2 : tensor<1x96x160x1x!quant.uniform> // CHECK-NOT: "tfl.reshape" - // CHECK{LITERAL}: "tfl.pseudo_qconst"() {qtype = tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>>, value = dense<[[[[12, -60, -51, -59, -62, 33, 53, 17, -31, 50, 27, 7, -19, -34, -14, -26], [47, -84, -32, -36, -102, -8, -8, 35, -33, 59, 95, 40, -25, -30, -55, 25]], [[4, -41, -61, 12, -23, 48, 40, 15, -39, 52, 81, -62, -24, 17, -7, -52], [40, -70, -45, 32, -43, 2, -30, 34, -35, 58, 77, -28, -30, 37, -47, -5]]]]> : tensor<1x2x2x16xi8>} : () -> tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>> + // CHECK{LITERAL}: "tfl.pseudo_qconst"() <{qtype = tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>>, value = dense<[[[[12, -60, -51, -59, -62, 33, 53, 17, -31, 50, 27, 7, -19, -34, -14, -26], [47, -84, -32, -36, -102, -8, -8, 35, -33, 59, 95, 40, -25, -30, -55, 25]], [[4, -41, -61, 12, -23, 48, 40, 15, -39, 52, 81, -62, -24, 17, -7, -52], [40, -70, -45, 32, -43, 2, -30, 34, -35, 58, 77, -28, -30, 37, -47, -5]]]]> : tensor<1x2x2x16xi8>}> : 
() -> tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>> // CHECK-NEXT: "tfl.transpose_conv" } diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index 27d98c7599c93d..d7ce1a065a8540 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -46,7 +46,7 @@ func.func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3 // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_1]], [[VAL_6]]) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<4x3xf32> // CHECK: [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> // CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<1x3xf32>, tensor<2xi32>) -> tensor<3x1xf32> -// CHECK-DAG: [[VAL_10:%.*]] = "tfl.no_value"() {value} : () -> none +// CHECK-DAG: [[VAL_10:%.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK-DAG: [[VAL_11:%.*]] = arith.constant dense<0> : tensor<2xi64> // CHECK-DAG: [[VAL_12:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64> // CHECK: [[VAL_13:%.*]] = "tf.Slice"([[VAL_7]], [[VAL_11]], [[VAL_12]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32> @@ -85,8 +85,9 @@ func.func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3 // CHECK-DAG: [[VAL_46:%.*]] = arith.constant dense<0.000000e+00> : tensor<3xf32> // CHECK-DAG: [[VAL_47:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x3xf32> // CHECK-DAG: [[VAL_48:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x1xf32> -// CHECK: [[VAL_49:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]], [[VAL_19]], [[VAL_13]], [[VAL_22]], [[VAL_28]], [[VAL_31]], [[VAL_25]], [[VAL_34]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_40]], [[VAL_41]], [[VAL_37]], [[VAL_42]], [[VAL_45]], [[VAL_46]], [[VAL_47]], [[VAL_48]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]]) ({ -// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, none, none, none, none) -> tensor<1x3xf32> +// CHECK: [[VAL_49:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]], [[VAL_19]], [[VAL_13]], [[VAL_22]], [[VAL_28]], [[VAL_31]], [[VAL_25]], [[VAL_34]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_40]], [[VAL_41]], [[VAL_37]], [[VAL_42]], [[VAL_45]], [[VAL_46]], [[VAL_47]], [[VAL_48]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]]) +// CHECK-SAME: <{cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32}> ({ +// CHECK: }) : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, none, none, none, none) -> tensor<1x3xf32> // CHECK: [[VAL_50:%.*]] = tensor.cast [[VAL_51:%.*]] : tensor<1x3xf32> to tensor<1x?xf32> // CHECK: return [[VAL_50]] : tensor<1x?xf32> @@ -98,7 +99,7 @@ func.func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3 // CHECK: [[VAL_53:%.*]] = 
"tf.Transpose"([[VAL_1]], [[VAL_52]]) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<4x3xf32> // CHECK: [[VAL_54:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> // CHECK: [[VAL_55:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_54]]) : (tensor<1x3xf32>, tensor<2xi32>) -> tensor<3x1xf32> -// CHECK-DAG: [[VAL_56:%.*]] = "tfl.no_value"() {value} : () -> none +// CHECK-DAG: [[VAL_56:%.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK-DAG: [[VAL_57:%.*]] = arith.constant dense<0> : tensor<2xi64> // CHECK-DAG: [[VAL_58:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64> // CHECK: [[VAL_59:%.*]] = "tf.Slice"([[VAL_53]], [[VAL_57]], [[VAL_58]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32> @@ -145,8 +146,9 @@ func.func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3 // CHECK: [[VAL_100:%.*]] = "tf.Slice"([[VAL_5]], [[VAL_98]], [[VAL_99]]) : (tensor<2xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xf32> // CHECK-DAG: [[VAL_101:%.*]] = arith.constant dense<0.000000e+00> : tensor<1xf32> // CHECK-DAG: [[VAL_102:%.*]] = arith.constant dense<0.000000e+00> : tensor<1xf32> -// CHECK: [[VAL_103:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_62]], [[VAL_65]], [[VAL_59]], [[VAL_68]], [[VAL_74]], [[VAL_77]], [[VAL_71]], [[VAL_80]], [[VAL_56]], [[VAL_56]], [[VAL_56]], [[VAL_86]], [[VAL_87]], [[VAL_83]], [[VAL_88]], [[VAL_91]], [[VAL_92]], [[VAL_93]], [[VAL_94]], [[VAL_100]], [[VAL_101]], [[VAL_97]], [[VAL_102]]) ({ -// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x3xf32> +// CHECK: [[VAL_103:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_62]], [[VAL_65]], [[VAL_59]], [[VAL_68]], [[VAL_74]], [[VAL_77]], [[VAL_71]], [[VAL_80]], [[VAL_56]], [[VAL_56]], [[VAL_56]], [[VAL_86]], [[VAL_87]], [[VAL_83]], [[VAL_88]], [[VAL_91]], [[VAL_92]], [[VAL_93]], [[VAL_94]], [[VAL_100]], [[VAL_101]], [[VAL_97]], [[VAL_102]]) +// CHECK-SAME: <{cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32}> ({ +// CHECK: }) : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x3xf32> // CHECK: [[VAL_104:%.*]] = tensor.cast [[VAL_105:%.*]] : tensor<1x3xf32> to tensor<1x?xf32> // CHECK: return [[VAL_104]] : tensor<1x?xf32> } @@ -203,8 +205,8 @@ func.func @inference_standard_lstm_time_major(%arg0: tensor, %arg1: t // CHECK-DAG: [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_20:%.*]] = 
"tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor +// CHECK: [[VAL_19:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor // CHECK-DAG: [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -257,8 +259,8 @@ func.func @inference_standard_indy_lstm_time_major(%arg0: tensor<8x8x8xf32>, %ar // CHECK-DAG: [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, 
[[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK-DAG: [[VAL_33:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -299,8 +301,8 @@ func.func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg // CHECK-DAG: [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_19:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK-DAG: [[VAL_21:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ 
-354,8 +356,8 @@ func.func @inference_standard_indy_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, // CHECK-DAG: [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK-DAG: [[VAL_33:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -398,8 +400,8 @@ func.func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_19:%.*]] = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_21:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : 
(tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32> +// CHECK: [[VAL_21:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_24:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_25:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -455,8 +457,8 @@ func.func @inference_standard_indy_lstm_time_major_go_backwards(%arg0: tensor<8x // CHECK-DAG: [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32> // CHECK: [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>,
tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK-DAG: [[VAL_33:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -499,8 +501,8 @@ func.func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8 // CHECK-DAG: [[VAL_18:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_19:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32> // CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_21:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_21:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_24:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_25:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -556,8 +558,8 @@ func.func @inference_standard_indy_lstm_non_time_major_go_backwards(%arg0: tenso // CHECK-DAG: [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32> // CHECK: [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK:
[[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK-DAG: [[VAL_33:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -605,8 +607,8 @@ func.func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor<?x8x8xf32>, // CHECK-DAG: [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32> // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32> +// CHECK: [[VAL_19:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32> // CHECK-DAG: [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -655,8 +657,8 @@ func.func @inference_standard_lstm_time_major_can_fuse_last_output(%arg0: tensor // CHECK-DAG: [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32> // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -// CHECK: [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32> +// CHECK: [[VAL_19:%.*]] = "tfl.no_value"() <{value}> : () -> none +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32> // CHECK-DAG: [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_23:%.*]] =
arith.constant dense<1> : tensor<3xi32> @@ -811,7 +813,7 @@ func.func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91x // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x100x4xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x100x91xf32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) attributes {tf._implements = "TFLite_Detection_PostProcess", tf._reference = "mlir"} { -// CHECK: %[[VAL_3:.*]]:4 = "tfl.custom"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {custom_code = "TFLite_Detection_PostProcess", custom_option = #tfl} : (tensor<1x100x4xf32>, tensor<1x100x91xf32>, tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) +// CHECK: %[[VAL_3:.*]]:4 = "tfl.custom"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{custom_code = "TFLite_Detection_PostProcess", custom_option = #tfl}> : (tensor<1x100x4xf32>, tensor<1x100x91xf32>, tensor<100x4xf32>) -> (tensor, tensor, tensor, tensor) // CHECK: return %[[VAL_3]]#0, %[[VAL_3]]#1, %[[VAL_3]]#2, %[[VAL_3]]#3 : tensor, tensor, tensor, tensor // CHECK: } } @@ -861,7 +863,7 @@ func.func @max_unpooling_2d(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi3 // CHECK-LABEL: func @max_unpooling_2d( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x2x1xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = "MaxUnpooling2D"} { -// CHECK-NEXT: %[[VAL_2:.*]] = "tfl.custom"(%[[VAL_0]], %[[VAL_1]]) {custom_code = "MaxUnpooling2D", custom_option = #tfl} : (tensor<1x1x2x1xf32>, tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> +// CHECK-NEXT: %[[VAL_2:.*]] = "tfl.custom"(%[[VAL_0]], %[[VAL_1]]) <{custom_code = "MaxUnpooling2D", custom_option = #tfl}> : (tensor<1x1x2x1xf32>, tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> // CHECK-NEXT: return %[[VAL_2]] : tensor<1x2x4x1xf32> // CHECK-NEXT: } } @@ -978,7 +980,7 @@ func.func private @__inference_interpolate_bilinear(%arg0: tensor<2x4x4x1xf32>, // CHECK-LABEL: func private @__inference_dense_image_warp( // CHECK-SAME: %arg0: tensor<2x4x4x1xf32>, // CHECK-SAME: %arg1: tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32> attributes {tf._implements = "DenseImageWarp"} { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "DenseImageWarp", custom_option = #tfl} : (tensor<2x4x4x1xf32>, tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "DenseImageWarp", custom_option = #tfl}> : (tensor<2x4x4x1xf32>, tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32> // CHECK-NEXT: return %0 : tensor<2x4x4x1xf32> // CHECK-NEXT: } } @@ -1014,7 +1016,7 @@ func.func private @dense_image_warp_invalid_output_type(%arg0: tensor<2x4x4x1xf3 // ----- module { -func.func @my_composite_op_150(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) attributes {tf._implements = #tf_type.func<@my_composite_op, {example_option = 10 : i64, tfl_fusable_op = true}>} { +func.func @my_composite_op_150(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) attributes {tf._implements = #tf_type.func<@my_composite_op, {example_option = 10 : i64, example_str = "value 1.01", tfl_fusable_op = true}>} { %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<*xf32> %1 = "tf.Identity"(%0) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> %2 = "tf.Mul"(%0, %arg2) {device = ""} : (tensor<*xf32>, tensor<4x4xf32>) -> tensor<*xf32> @@ -1022,8 +1024,8 @@ func.func @my_composite_op_150(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, 
% func.return %1, %3 : tensor<*xf32>, tensor<*xf32> } -// CHECK-LABEL: func @my_composite_op_150(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) attributes {tf._implements = #tf_type.func<@my_composite_op, {example_option = 10 : i64, tfl_fusable_op = true}>} { -// CHECK-NEXT: %0:2 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "my_composite_op", custom_option = #tfl} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) +// CHECK-LABEL: func @my_composite_op_150(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) attributes {tf._implements = #tf_type.func<@my_composite_op, {example_option = 10 : i64, example_str = "value 1.01", tfl_fusable_op = true}>} { +// CHECK-NEXT: %0:2 = "tfl.custom"(%arg0, %arg1, %arg2) <{custom_code = "my_composite_op", custom_option = #tfl}> : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) // CHECK-NEXT: return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32> // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir index b35355524127dc..baa41bc47b1d1f 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir @@ -20,32 +20,32 @@ func.func @QuantizeConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x64 // CHECK-DAG: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> -// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<64x3x3x3x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00, +// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<64x3x3x3x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00, // CHECK: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) -// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) { +// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) <{ // CHECK-NOT: asymmetric_quantize_inputs = true // CHECK-SAME: dilation_h_factor = 1 : i32 // CHECK: return %[[conv:.*]] // PerTensor-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> // PerTensor-DAG: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>} +// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>}> // PerTensor: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> -// PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) { +// PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) <{ // PerTensor-NOT: asymmetric_quantize_inputs = true // PerTensor-SAME: dilation_h_factor = 1 : i32 // PerTensor: return %[[conv:.*]] // MinElement-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> // MinElement-DAG: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// MinElement: %[[conv:.*]]= "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : 
(tensor<1x224x224x3xf32>, tensor<64x3x3x3xf32>, tensor<64xf32>) -> tensor<1x112x112x64xf32> +// MinElement: %[[conv:.*]]= "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x224x224x3xf32>, tensor<64x3x3x3xf32>, tensor<64xf32>) -> tensor<1x112x112x64xf32> // MinElement: return %[[conv:.*]] // Float16-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf16> // Float16-DAG: %[[b:.*]] = arith.constant dense<-1.237300e+00> : tensor<64xf16> // Float16: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<64x3x3x3xf16>) -> tensor<64x3x3x3xf32> // Float16: %[[dq_b:.*]] = "tfl.dequantize"(%[[b]]) : (tensor<64xf16>) -> tensor<64xf32> -// Float16: %[[conv:.*]]= "tfl.conv_2d"(%arg0, %[[dq_w]], %[[dq_b]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<64x3x3x3xf32>, tensor<64xf32>) -> tensor<1x112x112x64xf32> +// Float16: %[[conv:.*]]= "tfl.conv_2d"(%arg0, %[[dq_w]], %[[dq_b]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x224x224x3xf32>, tensor<64x3x3x3xf32>, tensor<64xf32>) -> tensor<1x112x112x64xf32> // Float16: return %[[conv:.*]] } @@ -63,32 +63,32 @@ func.func @QuantizeDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1 // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> // CHECK-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<64xf32> -// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00} +// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00}> // CHECK: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) -// CHECK: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w]], %[[b]]) { +// CHECK: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w]], %[[b]]) <{ // CHECK-NOT: asymmetric_quantize_inputs = true // CHECK-SAME: depth_multiplier = 4 : i32 // CHECK: return %[[dconv:.*]] // PerTensor-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> // PerTensor-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<64xf32> -// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>} +// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>}> // PerTensor: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> -// PerTensor: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w]], %[[b]]) { +// PerTensor: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w]], %[[b]]) <{ // PerTensor-NOT: asymmetric_quantize_inputs = true // PerTensor-SAME: depth_multiplier = 4 : i32 // PerTensor: return %[[dconv:.*]] // MinElement: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> // MinElement: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<64xf32> -// MinElement: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w]], %[[b]]) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = 
"VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<64x3x3x3xf32>, tensor<64xf32>) -> tensor<1x112x112x64xf32> +// MinElement: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w]], %[[b]]) <{depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<1x224x224x3xf32>, tensor<64x3x3x3xf32>, tensor<64xf32>) -> tensor<1x112x112x64xf32> // MinElement: return %[[dconv:.*]] // Float16-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf16> // Float16-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<64xf16> // Float16: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<64x3x3x3xf16>) -> tensor<64x3x3x3xf32> // Float16: %[[dq_b:.*]] = "tfl.dequantize"(%[[b]]) : (tensor<64xf16>) -> tensor<64xf32> -// Float16: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w]], %[[dq_b]]) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<64x3x3x3xf32>, tensor<64xf32>) -> tensor<1x112x112x64xf32> +// Float16: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w]], %[[dq_b]]) <{depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<1x224x224x3xf32>, tensor<64x3x3x3xf32>, tensor<64xf32>) -> tensor<1x112x112x64xf32> // Float16: return %[[dconv:.*]] } @@ -103,19 +103,19 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 func.return %fc : tensor<1x112x112x512xf32> // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<512x12xf32> -// CHECK-DAG: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00, +// CHECK-DAG: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00, // CHECK-DAG: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00, // CHECK-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) { +// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) <{ // CHECK-NOT: fused_activation_function = "NONE" // CHECK-SAME: asymmetric_quantize_inputs = true // CHECK: return %[[fc:.*]] // PerTensor-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<512x12xf32> -// PerTensor-DAG: %[[q_w:.*]]= "tfl.quantize"(%[[w:.*]]) {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>>} +// PerTensor-DAG: %[[q_w:.*]]= "tfl.quantize"(%[[w:.*]]) <{qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>>}> // PerTensor-DAG: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w:.*]]) : (tensor<512x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<512x12xf32> // PerTensor-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// PerTensor: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w:.*]], %[[b:.*]]) { +// PerTensor: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w:.*]], %[[b:.*]]) <{ // PerTensor-NOT: fused_activation_function = "NONE" // PerTensor-SAME: asymmetric_quantize_inputs = true // PerTensor: return %[[fc:.*]] @@ -132,21 +132,21 @@ 
func.func @QuantizeBatchMatmulWithActConst(%arg0: tensor<1x3x3x512xf32>) -> tens func.return %mm_s : tensor<1x3x3x2xf32> // CHECK: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<512x2xf32> -// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<512x2x!quant.uniform:f32, 1.000000e+00>>} +// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<512x2x!quant.uniform:f32, 1.000000e+00>>}> // CHECK: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<512x2x!quant.uniform:f32, 1.000000e+00>>) -> tensor<512x2xf32> -// CHECK: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[dq_w]]) {adj_x = false, adj_y = false +// CHECK: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[dq_w]]) <{adj_x = false, adj_y = false // CHECK-SAME: , asymmetric_quantize_inputs = true // CHECK: return %[[mm:.*]] // PerTensor: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<512x2xf32> -// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<512x2x!quant.uniform:f32, 1.000000e+00>>} +// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<512x2x!quant.uniform:f32, 1.000000e+00>>}> // PerTensor: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<512x2x!quant.uniform:f32, 1.000000e+00>>) -> tensor<512x2xf32> -// PerTensor: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[dq_w]]) {adj_x = false, adj_y = false +// PerTensor: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[dq_w]]) <{adj_x = false, adj_y = false // PerTensor-SAME: , asymmetric_quantize_inputs = true // PerTensor: return %[[mm:.*]] // MinElement: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<512x2xf32> -// MinElement: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[w]]) {adj_x = false, adj_y = false} : (tensor<1x3x3x512xf32>, tensor<512x2xf32>) -> tensor<1x3x3x2xf32> +// MinElement: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[w]]) <{adj_x = false, adj_y = false}> : (tensor<1x3x3x512xf32>, tensor<512x2xf32>) -> tensor<1x3x3x2xf32> // MinElement: return %[[mm:.*]] } @@ -160,11 +160,11 @@ func.func @NotQuantizeBatchMatmulWithConstAct(%arg0: tensor<1x1x3x512xf32>) -> t func.return %mm_s : tensor<1x1x12x3xf32> // CHECK: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1x1x12x512xf32> -// CHECK: %[[mm:.*]] = "tfl.batch_matmul"(%[[w]], %arg0) {adj_x = false, adj_y = true} +// CHECK: %[[mm:.*]] = "tfl.batch_matmul"(%[[w]], %arg0) <{adj_x = false, adj_y = true}> // CHECK: return %[[mm:.*]] // PerTensor: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1x1x12x512xf32> -// PerTensor: %[[mm:.*]] = "tfl.batch_matmul"(%[[w]], %arg0) {adj_x = false, adj_y = true} +// PerTensor: %[[mm:.*]] = "tfl.batch_matmul"(%[[w]], %arg0) <{adj_x = false, adj_y = true}> // PerTensor: return %[[mm:.*]] } @@ -176,10 +176,10 @@ func.func @NotQuantizeBatchMatmulWithActAct(%arg0: tensor<1x3x3x512xf32>) -> ten %mm_s = "quantfork.stats"(%mm) {layerStats = dense<[0.000000e+00, 1.000000e+01]> : tensor<2xf32>} : (tensor<1x3x3x3xf32>) -> tensor<1x3x3x3xf32> func.return %mm : tensor<1x3x3x3xf32> -// CHECK: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %arg0) {adj_x = false, adj_y = true} +// CHECK: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %arg0) <{adj_x = false, adj_y = true}> // CHECK: return %[[mm:.*]] -// PerTensor: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %arg0) {adj_x = false, adj_y = true} +// PerTensor: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %arg0) <{adj_x = false, adj_y = true}> // PerTensor: return %[[mm:.*]] } @@ -212,31 +212,31 @@ func.func @QuantizeCustomOp(%arg0: tensor<1x1x1x1xf32>) -> (tensor<*xf32>, tenso // CHECK: %[[w_1:.*]] = arith.constant 
dense<1.270000e+02> : tensor<4096x1x1x1xf32> // CHECK: %[[w_2:.*]] = arith.constant dense<1.270000e+02> : tensor<128x1x1x1xf32> // CHECK: %[[b:.*]] = arith.constant dense<1.270000e+02> : tensor<2048x1x1x1xf32> -// CHECK: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> -// CHECK: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp2", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> -// CHECK: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp3", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp2", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// CHECK: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp3", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> // CHECK: return %[[custom_1:.*]], %[[custom_2:.*]], %[[custom_3:.*]] // CustomOp-DAG: %[[w_1:.*]] = arith.constant dense<1.270000e+02> : tensor<4096x1x1x1xf32> // CustomOp-DAG: %[[w_2:.*]] = arith.constant dense<1.270000e+02> : tensor<128x1x1x1xf32> // CustomOp-DAG: %[[b:.*]] = arith.constant dense<1.270000e+02> : tensor<2048x1x1x1xf32> -// CustomOp-DAG: %[[q_w1:.*]] = "tfl.quantize"(%[[w_1]]) {qtype = tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>>} : (tensor<4096x1x1x1xf32>) -> tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>> -// CustomOp-DAG: %[[q_b:.*]] = "tfl.quantize"(%[[b]]) {qtype = tensor<2048x1x1x1x!quant.uniform:f32, 1.000000e+00>>} : (tensor<2048x1x1x1xf32>) -> tensor<2048x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// CustomOp-DAG: %[[q_w1:.*]] = "tfl.quantize"(%[[w_1]]) <{qtype = tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>>}> : (tensor<4096x1x1x1xf32>) -> tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// CustomOp-DAG: %[[q_b:.*]] = "tfl.quantize"(%[[b]]) <{qtype = tensor<2048x1x1x1x!quant.uniform:f32, 1.000000e+00>>}> : (tensor<2048x1x1x1xf32>) -> tensor<2048x1x1x1x!quant.uniform:f32, 1.000000e+00>> // CustomOp-DAG: %[[dq_w1:.*]] = "tfl.dequantize"(%[[q_w1]]) : (tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> tensor<4096x1x1x1xf32> // CustomOp: %[[dq_b:.*]] = "tfl.dequantize"(%[[q_b]]) : (tensor<2048x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> tensor<2048x1x1x1xf32> -// CustomOp: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[dq_w1]], %[[w_2]], %[[dq_b]]) {custom_code = "CustomTestOp", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> -// CustomOp: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp2", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, 
tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> -// CustomOp: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[dq_b]]) {custom_code = "CustomTestOp3", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// CustomOp: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[dq_w1]], %[[w_2]], %[[dq_b]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// CustomOp: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp2", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// CustomOp: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[dq_b]]) <{custom_code = "CustomTestOp3", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> // CustomOp: return %[[custom_1:.*]], %[[custom_2:.*]], %[[custom_3:.*]] // MinElement-DAG: %[[w_1:.*]] = arith.constant dense<1.270000e+02> : tensor<4096x1x1x1xf32> -// MinElement-DAG: %[[q_w1:.*]] = "tfl.quantize"(%[[w_1]]) {qtype = tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>>} : (tensor<4096x1x1x1xf32>) -> tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// MinElement-DAG: %[[q_w1:.*]] = "tfl.quantize"(%[[w_1]]) <{qtype = tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>>}> : (tensor<4096x1x1x1xf32>) -> tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>> // MinElement-DAG: %[[dq_w1:.*]] = "tfl.dequantize"(%[[q_w1]]) : (tensor<4096x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> tensor<4096x1x1x1xf32> // MinElement-DAG: %[[w_2:.*]] = arith.constant dense<1.270000e+02> : tensor<128x1x1x1xf32> // MinElement-DAG: %[[b:.*]] = arith.constant dense<1.270000e+02> : tensor<2048x1x1x1xf32> -// MinElement: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[dq_w1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> -// MinElement: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp2", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> -// MinElement: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) {custom_code = "CustomTestOp3", custom_option = #tfl} : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// MinElement: %[[custom_1:.*]] = "tfl.custom"(%arg0, %[[dq_w1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// MinElement: %[[custom_2:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp2", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> +// MinElement: %[[custom_3:.*]] = "tfl.custom"(%arg0, %[[w_1]], %[[w_2]], %[[b]]) <{custom_code = "CustomTestOp3", custom_option = #tfl}> : (tensor<1x1x1x1xf32>, tensor<4096x1x1x1xf32>, tensor<128x1x1x1xf32>, tensor<2048x1x1x1xf32>) -> tensor<*xf32> // MinElement: return %[[custom_1:.*]], %[[custom_2:.*]], %[[custom_3:.*]] 
} @@ -252,18 +252,18 @@ func.func @QuantizeTransposeConvWeightOnly(%arg0: tensor<32x4x4x128xf32>, %arg1: // CHECK-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<1x32x42x128xf32> // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1x32x42x128xf32> -// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>} : (tensor<1x32x42x128xf32>) -> tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>> +// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>}> : (tensor<1x32x42x128xf32>) -> tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>> // CHECK: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>) -> tensor<1x32x42x128xf32> -// CHECK: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w:.*]], %arg0, %[[b:.*]]) { +// CHECK: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w:.*]], %arg0, %[[b:.*]]) <{ // CHECK-NOT: asymmetric_quantize_inputs = true // CHECK-SAME: padding = "SAME" // CHECK: return %[[tconv:.*]] // PerTensor-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1x32x42x128xf32> // PerTensor-DAG: %[[b:.*]]= arith.constant dense<0.000000e+00> : tensor<1x32x42x128xf32> -// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>} : (tensor<1x32x42x128xf32>) -> tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>}> : (tensor<1x32x42x128xf32>) -> tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>> // PerTensor: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32> -// PerTensor: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w:.*]], %arg0, %[[b:.*]]) { +// PerTensor: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w:.*]], %arg0, %[[b:.*]]) <{ // PerTensor-NOT: asymmetric_quantize_inputs = true // PerTensor-SAME: padding = "SAME" // PerTensor: return %[[tconv:.*]] @@ -278,13 +278,13 @@ func.func @QuantizeGatherWeightOnly(%arg0: tensor<3xi32>) -> tensor<3x3x3x3xf32> func.return %emb_s : tensor<3x3x3x3xf32> // CHECK: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> -// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// CHECK: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // CHECK: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> // CHECK: %[[emb:.*]] = "tfl.gather"(%[[dq_w]], %arg0) // CHECK: return %[[emb:.*]] // PerTensor: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> -// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // PerTensor: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> // PerTensor: %[[emb:.*]] = "tfl.gather"(%[[dq_w]], %arg0) // PerTensor: return %[[emb:.*]] @@ -312,10 +312,10 @@ func.func @NotQuantizeConv3D(%arg0: tensor) -> tensor // CHECK-DAG: %[[out_ch:.*]] = arith.constant dense<16> : tensor<1xi64> -// CHECK-DAG: %[[const:.*]] = "tfl.no_value"() {value} 
: () -> none +// CHECK-DAG: %[[const:.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<3x3x3x8x16xf32> // CHECK-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> -// CHECK: %[[conv3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[const]]) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor, tensor<3x3x3x8x16xf32>, none) -> tensor +// CHECK: %[[conv3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[const]]) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor, tensor<3x3x3x8x16xf32>, none) -> tensor // CHECK: %2 = "tfl.shape"(%[[conv3d]]) : (tensor) -> tensor<5xi64> // CHECK: %3 = "tfl.broadcast_args"(%2, %[[out_ch]]) : (tensor<5xi64>, tensor<1xi64>) -> tensor<5xi64> // CHECK: %4 = "tfl.broadcast_to"(%[[conv3d]], %3) : (tensor, tensor<5xi64>) -> tensor @@ -324,10 +324,10 @@ func.func @NotQuantizeConv3D(%arg0: tensor) -> tensor // PerTensor: %[[out_ch:.*]] = arith.constant dense<16> : tensor<1xi64> -// PerTensor: %[[const:.*]] = "tfl.no_value"() {value} : () -> none +// PerTensor: %[[const:.*]] = "tfl.no_value"() <{value}> : () -> none // PerTensor: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<3x3x3x8x16xf32> // PerTensor: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> -// PerTensor: %[[conv3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[const]]) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor, tensor<3x3x3x8x16xf32>, none) -> tensor +// PerTensor: %[[conv3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[const]]) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor, tensor<3x3x3x8x16xf32>, none) -> tensor // PerTensor: %2 = "tfl.shape"(%[[conv3d]]) : (tensor) -> tensor<5xi64> // PerTensor: %3 = "tfl.broadcast_args"(%2, %[[out_ch]]) : (tensor<5xi64>, tensor<1xi64>) -> tensor<5xi64> // PerTensor: %4 = "tfl.broadcast_to"(%[[conv3d]], %3) : (tensor, tensor<5xi64>) -> tensor @@ -338,10 +338,10 @@ func.func @NotQuantizeConv3D(%arg0: tensor) -> tensor : tensor<1xi64> // Float16-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<3x3x3x8x16xf16> // Float16-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf16> -// Float16-DAG: %[[const:.*]] = "tfl.no_value"() {value} : () -> none +// Float16-DAG: %[[const:.*]] = "tfl.no_value"() <{value}> : () -> none // Float16-DAG: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<3x3x3x8x16xf16>) -> tensor<3x3x3x8x16xf32> // Float16-DAG: %[[dq_b:.*]] = "tfl.dequantize"(%[[b]]) : (tensor<16xf16>) -> tensor<16xf32> -// Float16: %[[conv3d:.*]] = "tfl.conv_3d"(%arg0, %[[dq_w]], %[[const]]) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor, tensor<3x3x3x8x16xf32>, none) -> tensor +// Float16: %[[conv3d:.*]] = "tfl.conv_3d"(%arg0, %[[dq_w]], %[[const]]) 
<{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor, tensor<3x3x3x8x16xf32>, none) -> tensor // Float16: %4 = "tfl.shape"(%[[conv3d]]) : (tensor) -> tensor<5xi64> // Float16: %5 = "tfl.broadcast_args"(%4, %[[out_ch]]) : (tensor<5xi64>, tensor<1xi64>) -> tensor<5xi64> // Float16: %6 = "tfl.broadcast_to"(%[[conv3d]], %5) : (tensor, tensor<5xi64>) -> tensor @@ -367,24 +367,24 @@ func.func @QuantizeMultiUses(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112 // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> // CHECK-DAG: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// CHECK-DAG: %[[q_w1:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00} -// CHECK-DAG: %[[q_w2:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<64x3x3x3x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00 +// CHECK-DAG: %[[q_w1:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00}> +// CHECK-DAG: %[[q_w2:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<64x3x3x3x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00 // CHECK-DAG: %[[dq_w1:.*]] = "tfl.dequantize"(%[[q_w1]]) // CHECK-DAG: %[[dq_w2:.*]] = "tfl.dequantize"(%[[q_w2]]) // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w2]], %[[b]]) // CHECK: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w1]], %[[b]]) -// CHECK: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) {adj_x = false, adj_y = true +// CHECK: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) <{adj_x = false, adj_y = true // CHECK-NOT: , asymmetric_quantize_inputs = true // CHECK-SAME: } // CHECK: return %[[bmm:.*]] // PerTensor-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<64x3x3x3xf32> // PerTensor-DAG: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>} +// PerTensor: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>}> // PerTensor: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> // PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) // PerTensor: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w]], %[[b]]) -// PerTensor: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) {adj_x = false, adj_y = true +// PerTensor: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) <{adj_x = false, adj_y = true // PerTensor-NOT: , asymmetric_quantize_inputs = true // PerTensor-SAME: } // PerTensor: return %[[bmm:.*]] @@ -395,7 +395,7 @@ func.func @QuantizeMultiUses(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112 // Float16-DAG: %[[dq_b:.*]] = "tfl.dequantize"(%[[b:.*]]) : (tensor<64xf16>) -> tensor<64xf32> // Float16: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[dq_b]]) // Float16: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w]], %[[dq_b]]) -// Float16: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) {adj_x = false, adj_y = true +// Float16: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) <{adj_x = false, adj_y = true // Float16: return %[[bmm:.*]] } diff --git 
a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir index 15ede0019e12d6..31d31c656b1d5f 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir @@ -58,7 +58,7 @@ func.func @QuantizeUnidirectionalLstmFullPerTensor(%arg0: tensor<1x2x3xf32>) -> // CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0047244096365500624>>) -> tensor<1x1xf32> // CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0055118109297564652>>) -> tensor<1x1xf32> // CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0062992126922907796>>) -> tensor<1x1xf32> -// CHECK-DAG: %[[input_9:.*]] = "tfl.no_value"() {value} : () -> none +// CHECK-DAG: %[[input_9:.*]] = "tfl.no_value"() <{value}> : () -> none // CHECK-DAG: %[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> // CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> // CHECK-DAG: %[[input_12:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> @@ -73,7 +73,7 @@ func.func @QuantizeUnidirectionalLstmFullPerTensor(%arg0: tensor<1x2x3xf32>) -> // CHECK-SAME: %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], // CHECK-SAME: %[[input_9]], %[[input_9]], // CHECK-SAME: %[[input_14]], %[[input_15]], -// CHECK-SAME: %[[input_9]], %[[input_9]], %[[input_9]], %[[input_9]]) { +// CHECK-SAME: %[[input_9]], %[[input_9]], %[[input_9]], %[[input_9]]) <{ // CHECK-SAME: asymmetric_quantize_inputs = false, // CHECK-SAME: cell_clip = 1.000000e+01 : f32, // CHECK-SAME: effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, @@ -83,7 +83,7 @@ func.func @QuantizeUnidirectionalLstmFullPerTensor(%arg0: tensor<1x2x3xf32>) -> // CHECK-SAME: input_to_input_intermediate = tensor<0xf32>, // CHECK-SAME: input_to_output_intermediate = tensor<0xf32>, // CHECK-SAME: proj_clip = 0.000000e+00 : f32, -// CHECK-SAME: time_major = false} : ( +// CHECK-SAME: time_major = false}> : ( // CHECK-SAME: tensor<1x2x3xf32>, // CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, // CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, @@ -93,7 +93,7 @@ func.func @QuantizeUnidirectionalLstmFullPerTensor(%arg0: tensor<1x2x3xf32>) -> // CHECK-SAME: tensor<1x3xf32>, tensor<1x3xf32>, // CHECK-SAME: none, none, none, none) // CHECK-SAME: -> tensor<1x2x3xf32> -// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<1x2x3x!quant.uniform>, volatile} : (tensor<1x2x3xf32>) -> tensor<1x2x3x!quant.uniform> +// CHECK: "tfl.quantize"(%[[lstm]]) <{qtype = tensor<1x2x3x!quant.uniform>}> {volatile} : (tensor<1x2x3xf32>) -> tensor<1x2x3x!quant.uniform> } @@ -176,7 +176,7 @@ func.func @QuantizeUnidirectionalLstmFullPerAxis(%arg0: tensor<1x2x3xf32>) -> (t // CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0047244096365500624>>) -> tensor<1x1xf32> // CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0055118109297564652>>) -> tensor<1x1xf32> // CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0062992126922907796>>) -> tensor<1x1xf32> -// CHECK-DAG: %[[input_9:.*]] = "tfl.no_value"() {value} : () -> none 
+// CHECK-DAG: %[[input_9:.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK-DAG: %[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32>
// CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32>
// CHECK-DAG: %[[input_12:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32>
@@ -191,14 +191,14 @@ func.func @QuantizeUnidirectionalLstmFullPerAxis(%arg0: tensor<1x2x3xf32>) -> (t
// CHECK-SAME: %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]],
// CHECK-SAME: %[[input_9]], %[[input_9]],
// CHECK-SAME: %[[input_14]], %[[input_15]],
-// CHECK-SAME: %[[input_9]], %[[input_9]], %[[input_9]], %[[input_9]]) {
+// CHECK-SAME: %[[input_9]], %[[input_9]], %[[input_9]], %[[input_9]]) <{
// CHECK-SAME: asymmetric_quantize_inputs = false,
// CHECK-SAME: cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>,
// CHECK-SAME: fused_activation_function = "TANH",
// CHECK-SAME: input_to_cell_intermediate = tensor<0xf32>,
// CHECK-SAME: input_to_forget_intermediate = tensor<0xf32>,
// CHECK-SAME: input_to_input_intermediate = tensor<0xf32>,
-// CHECK-SAME: input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (
+// CHECK-SAME: input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false}> : (
// CHECK-SAME: tensor<1x2x3xf32>,
// CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>,
// CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>,
@@ -208,7 +208,7 @@ func.func @QuantizeUnidirectionalLstmFullPerAxis(%arg0: tensor<1x2x3xf32>) -> (t
// CHECK-SAME: tensor<1x3xf32>, tensor<1x3xf32>,
// CHECK-SAME: none, none, none, none)
// CHECK-SAME: -> tensor<1x2x3xf32>
-// CHECK: %32 = "tfl.quantize"(%31) {qtype = tensor<1x2x3x!quant.uniform>, volatile} : (tensor<1x2x3xf32>) -> tensor<1x2x3x!quant.uniform>
+// CHECK: %32 = "tfl.quantize"(%31) <{qtype = tensor<1x2x3x!quant.uniform>}> {volatile} : (tensor<1x2x3xf32>) -> tensor<1x2x3x!quant.uniform>
}
@@ -219,10 +219,10 @@ func.func @QuantizeFixedOutputRangeInterfaceOpSoftmax(%arg0: tensor<1x1xf32>) ->
%2 = "quantfork.stats"(%1) {layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32>
func.return %2 : tensor<1x1xf32>
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[sm:.*]] = "tfl.softmax"(%[[dq1]]) {{{.*}}} : (tensor<1x1xf32>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[sm]]) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[sm:.*]] = "tfl.softmax"(%[[dq1]]) <{{{.*}}}> : (tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[sm]]) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
}
@@ -233,10 +233,10 @@ func.func @QuantizeFixedOutputRangeInterfaceOpL2Normalization(%arg0: tensor<1x1x
%2 = "quantfork.stats"(%1) {layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32>
func.return %2 : tensor<1x1xf32>
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[l2:.*]] = "tfl.l2_normalization"(%[[dq1]]) {{{.*}}} : (tensor<1x1xf32>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[l2]]) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[l2:.*]] = "tfl.l2_normalization"(%[[dq1]]) <{{{.*}}}> : (tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[l2]]) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
}
@@ -247,10 +247,10 @@ func.func @QuantizeFixedOutputRangeInterfaceOpLogistic(%arg0: tensor<1x1xf32>) -
%2 = "quantfork.stats"(%1) {layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32>
func.return %2 : tensor<1x1xf32>
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: %[[lo:.*]] = "tfl.logistic"(%[[dq1]]) : (tensor<1x1xf32>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[lo]]) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[lo]]) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
}
@@ -261,10 +261,10 @@ func.func @QuantizeFixedOutputRangeInterfaceOpTanh(%arg0: tensor<1x1xf32>) -> (t
%2 = "quantfork.stats"(%1) {layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32>
func.return %2 : tensor<1x1xf32>
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: %[[ta:.*]] = "tfl.tanh"(%[[dq1]]) : (tensor<1x1xf32>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[ta]]) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[ta]]) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
}
@@ -277,10 +277,10 @@ func.func @QuantizeReshapeOp(%arg0: tensor<1x1x3xf32>) -> (tensor<1x3xf32>) {
func.return %4 : tensor<1x3xf32>
// CHECK: %[[cst:.*]] = arith.constant dense<[-1, 3]> : tensor<2xi32>
-// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x3x!quant.uniform>, volatile} : (tensor<1x1x3xf32>) -> tensor<1x1x3x!quant.uniform>
+// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x1x3x!quant.uniform>}> {volatile} : (tensor<1x1x3xf32>) -> tensor<1x1x3x!quant.uniform>
// CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x1x3x!quant.uniform>) -> tensor<1x1x3xf32>
// CHECK-NEXT: %[[rs:.*]] = "tfl.reshape"(%[[dq1]], %[[cst]]) : (tensor<1x1x3xf32>, tensor<2xi32>) -> tensor<1x3xf32>
-// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[rs]]) {qtype = tensor<1x3x!quant.uniform>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform>
+// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[rs]]) <{qtype = tensor<1x3x!quant.uniform>}> {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform>
// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32>
// CHECK-NEXT: return %[[dq2]] : tensor<1x3xf32>
}
@@ -295,15 +295,15 @@ func.func @QuantizeFullyConnectedOp(%arg0: tensor<1x3xf32>) -> (tensor<1x1xf32>)
func.return %5 : tensor<1x1xf32>
// CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : tensor<1xf32>
-// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<1x!quant.uniform>, volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform>
+// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<1x!quant.uniform>}> {volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform>
// CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x!quant.uniform>) -> tensor<1xf32>
// CHECK-NEXT: %[[cst_0:.*]] = arith.constant dense<{{.*}}> : tensor<1x3xf32>
-// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32:0, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32:0, {{.*}}>>
+// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) <{qtype = tensor<1x3x!quant.uniform:f32:0, {{.*}}>>}> {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32:0, {{.*}}>>
// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32:0, {{.*}}>>) -> tensor<1x3xf32>
-// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x3x!quant.uniform>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform>
+// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x3x!quant.uniform>}> {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform>
// CHECK-NEXT: %[[dq3:.*]] = "tfl.dequantize"(%[[q3]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32>
-// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%[[dq3]], %[[dq2]], %[[dq1]]) {{{.*}}} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1xf32>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[q4:.*]] = "tfl.quantize"(%[[fc]]) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%[[dq3]], %[[dq2]], %[[dq1]]) <{{{.*}}}> : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1xf32>) -> tensor<1x1xf32>
+// CHECK-NEXT: %[[q4:.*]] = "tfl.quantize"(%[[fc]]) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq4:.*]] = "tfl.dequantize"(%[[q4]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: return %[[dq4]] : tensor<1x1xf32>
}
@@ -321,19 +321,19 @@ func.func @QuantizeReshapeAndFullyConnectedOp(%arg0: tensor<1x1x3xf32>) -> (tens
func.return %8 : tensor<1x1xf32>
// CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : tensor<1xf32>
-// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<1x!quant.uniform>, volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform>
+// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<1x!quant.uniform>}> {volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform>
// CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x!quant.uniform>) -> tensor<1xf32>
// CHECK-NEXT: %[[cst_0:.*]] = arith.constant dense<{{.*}}> : tensor<1x3xf32>
-// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32:0, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32:0, {{.*}}>>
+// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) <{qtype = tensor<1x3x!quant.uniform:f32:0, {{.*}}>>}> {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32:0, {{.*}}>>
// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32:0, {{.*}}>>) -> tensor<1x3xf32>
// CHECK-NEXT: %[[cst_1:.*]] = arith.constant dense<[-1, 3]> : tensor<2xi32>
-// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x3x!quant.uniform>, volatile} : (tensor<1x1x3xf32>) -> tensor<1x1x3x!quant.uniform>
+// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x1x3x!quant.uniform>}> {volatile} : (tensor<1x1x3xf32>) -> tensor<1x1x3x!quant.uniform>
// CHECK-NEXT: %[[dq3:.*]] = "tfl.dequantize"(%[[q3]]) : (tensor<1x1x3x!quant.uniform>) -> tensor<1x1x3xf32>
// CHECK-NEXT: %[[rs:.*]] = "tfl.reshape"(%[[dq3]], %[[cst_1]]) : (tensor<1x1x3xf32>, tensor<2xi32>) -> tensor<1x3xf32>
-// CHECK-NEXT: %[[q4:.*]] = "tfl.quantize"(%[[rs]]) {qtype = tensor<1x3x!quant.uniform>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform>
+// CHECK-NEXT: %[[q4:.*]] = "tfl.quantize"(%[[rs]]) <{qtype = tensor<1x3x!quant.uniform>}> {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform>
// CHECK-NEXT: %[[dq4:.*]] = "tfl.dequantize"(%[[q4]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32>
-// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%[[dq4]], %[[dq2]], %[[dq1]]) {{{.*}}} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1xf32>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[q5:.*]] = "tfl.quantize"(%[[fc]]) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%[[dq4]], %[[dq2]], %[[dq1]]) <{{{.*}}}> : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1xf32>) -> tensor<1x1xf32>
+// CHECK-NEXT: %[[q5:.*]] = "tfl.quantize"(%[[fc]]) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[dq5:.*]] = "tfl.dequantize"(%[[q5]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: return %[[dq5]] : tensor<1x1xf32>
}
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir
index 0b3ea28406825a..bf397c2d1b8463 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir
@@ -29,9 +29,9 @@ func.func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf
none, none, none, none) -> tensor<1x28x20xf32>
%1 = "quantfork.stats"(%0) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<1x28x20xf32>) -> tensor<1x28x20xf32>
func.return %1 : tensor<1x28x20xf32>
-// CHECK-DAG: %[[none:.*]] = "tfl.no_value"() {value} : () -> none
+// CHECK-DAG: %[[none:.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK-DAG: %[[cell_input:.*]] = arith.constant dense<1.000000e+00> : tensor<1x20xf32>
-// CHECK-DAG: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform>
+// CHECK-DAG: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) <{qtype = tensor<1x20x!quant.uniform>}> : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform>
// CHECK-DAG: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform>) -> tensor<1x20xf32>
// Checks if input 19 is correctly passed from a dequantize op.
// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
@@ -113,7 +113,7 @@ func.func @QuantizeWithoutNorm(%arg0: tensor<1x1x5xf32>) -> tensor<*xf32> attrib
// CHECK-SAME: %[[input_9]], %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]]
// CHECK-SAME: effective_hidden_scale_intermediate = tensor>
-// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>, volatile}
+// CHECK: "tfl.quantize"(%[[lstm]]) <{qtype = tensor<*x!quant.uniform>}> {volatile}
}
// CHECK-LABEL: QuantizeLstmCifg
@@ -166,7 +166,7 @@ func.func @QuantizeLstmCifg(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes
%24 = "quantfork.stats"(%23) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32>
func.return %24 : tensor<*xf32>
-// CHECK-DAG: %[[none:.*]] = "tfl.no_value"() {value} : () -> none
+// CHECK-DAG: %[[none:.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x5x!quant.uniform>) -> tensor<1x5xf32>
// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform:f32, 0.018341723389512912>>) -> tensor<2x5xf32>
// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x5x!quant.uniform:f32, 0.011170119751156785>>) -> tensor<2x5xf32>
@@ -190,12 +190,12 @@ func.func @QuantizeLstmCifg(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes
// CHECK: %[[lstm:.*]] = "tfl.lstm"(%[[input_0]], %[[none]], %[[input_2]], %[[input_3]], %[[input_4]], %[[none]], %[[input_6]], %[[input_7]], %[[input_8]],
// CHECK-SAME: %[[none]], %[[input_10]], %[[input_11]], %[[none]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]],
// CHECK-SAME: %[[none]], %[[input_21]], %[[input_22]], %[[input_23]])
-// CHECK-NEXT: effective_hidden_scale_intermediate = tensor>
+// CHECK-SAME: effective_hidden_scale_intermediate = tensor>
// CHECK-SAME: input_to_cell_intermediate = tensor:f32, 1.2207403790398877E-4>>
// CHECK-SAME: input_to_forget_intermediate = tensor:f32, 4.8829615161595508E-4>>
// CHECK-SAME: input_to_output_intermediate = tensor:f32, 3.0518509475997192E-5>>
-// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>, volatile}
+// CHECK: "tfl.quantize"(%[[lstm]]) <{qtype = tensor<*x!quant.uniform>}> {volatile}
}
// CHECK-LABEL: QuantizeUnidirectionalLstmFull
@@ -286,7 +286,7 @@ func.func @QuantizeUnidirectionalLstmFull(%arg0: tensor<1x1x5xf32>) -> tensor<*x
// CHECK-SAME: input_to_input_intermediate = tensor:f32, 9.7659230323191015E-4>>
// CHECK-SAME: input_to_output_intermediate = tensor:f32, 3.0518509475997192E-5>>
-// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>, volatile}
+// CHECK: "tfl.quantize"(%[[lstm]]) <{qtype = tensor<*x!quant.uniform>}> {volatile}
}
// CHECK-LABEL: QuantizeUnidirectionalLstmWithFixedOutputRangedInput
@@ -481,13 +481,13 @@ func.func @QuantizeLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes
// CHECK: %[[lstm:.*]] = "tfl.lstm"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]], %[[input_5]], %[[input_6]], %[[input_7]], %[[input_8]],
// CHECK-SAME: %[[input_9]], %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]],
// CHECK-SAME: %[[input_20]], %[[input_21]], %[[input_22]], %[[input_23]])
-// CHECK-NEXT: effective_hidden_scale_intermediate = tensor>
+// CHECK-SAME: effective_hidden_scale_intermediate = tensor>
// CHECK-SAME: input_to_cell_intermediate = tensor:f32, 1.2207403790398877E-4>>
// CHECK-SAME: input_to_forget_intermediate = tensor:f32, 4.8829615161595508E-4>>
// CHECK-SAME: input_to_input_intermediate = tensor:f32, 9.7659230323191015E-4>>
// CHECK-SAME: input_to_output_intermediate = tensor:f32, 3.0518509475997192E-5>>
-// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>, volatile}
+// CHECK: "tfl.quantize"(%[[lstm]]) <{qtype = tensor<*x!quant.uniform>}> {volatile}
}
// CHECK-LABEL: QuantizeSVDF
@@ -508,7 +508,7 @@ func.func @QuantizeSVDF(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> {
// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform>)
// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform:f32, 0.0037514108011770368>>)
// CHECK: %[[svdf:.*]] = "tfl.svdf"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]])
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) {qtype = tensor<1x2x!quant.uniform>, volatile}
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) <{qtype = tensor<1x2x!quant.uniform>}> {volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: return %[[dq]]
}
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir
index 882b335135cf74..c6a2eb88e09e8f 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir
@@ -7,7 +7,7 @@ func.func @uint8_to_int8(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32>
func.return %2 : tensor<2x2xf32>
-// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>)
+// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>)
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: return %[[dq]] : tensor<2x2xf32>
}
@@ -18,7 +18,7 @@ func.func @uint8_to_int8_per_axis(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32>
func.return %2 : tensor<2x2xf32>
-// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform>}
+// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<2x2x!quant.uniform>}>
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%0)
// CHECK-NEXT: return %[[dq]] : tensor<2x2xf32>
}
@@ -29,7 +29,7 @@ func.func @uint8_to_int8_narrow_range(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform:f32, 1.0:255>>) -> tensor<2x2xf32>
func.return %2 : tensor<2x2xf32>
-// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform:f32, 1.000000e+00:127>>}
+// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<2x2x!quant.uniform:f32, 1.000000e+00:127>>}>
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: return %[[dq]] : tensor<2x2xf32>
}
@@ -49,9 +49,9 @@ func.func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
func.return %1 : tensor<8x4x3xf32>
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<8x4x3x!quant.uniform>}> {volatile}
// CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
-// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
+// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) <{qtype = tensor<8x4x3x!quant.uniform>}> {volatile}
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
// CHECK: return %[[dq2]]
}
@@ -71,9 +71,9 @@ func.func @prepareStatisticsNudge(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
func.return %1 : tensor<8x4x3xf32>
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<8x4x3x!quant.uniform>}> {volatile}
// CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
-// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
+// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) <{qtype = tensor<8x4x3x!quant.uniform>}> {volatile}
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
// CHECK: return %[[dq2]]
}
@@ -85,7 +85,7 @@ func.func @preparePrelu(%arg0: tensor<1x10x10x3xf32>) -> tensor<1x10x10x3xf32> {
func.return %prelu : tensor<1x10x10x3xf32>
// CHECK: %[[cst:.*]] = arith.constant dense<[{{\[}}[1.66394591, 3.61694336, 2.0382936]]]> : tensor<1x1x3xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<1x1x3x!quant.uniform:f32, 0.028479868971456691>>, volatile} : (tensor<1x1x3xf32>) -> tensor<1x1x3x!quant.uniform:f32, 0.028479868971456691>>
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<1x1x3x!quant.uniform:f32, 0.028479868971456691>>}> {volatile} : (tensor<1x1x3xf32>) -> tensor<1x1x3x!quant.uniform:f32, 0.028479868971456691>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x1x3x!quant.uniform:f32, 0.028479868971456691>>) -> tensor<1x1x3xf32>
// CHECK: %[[p:.*]] = "tfl.prelu"(%arg0, %[[dq]]) : (tensor<1x10x10x3xf32>, tensor<1x1x3xf32>) -> tensor<1x10x10x3xf32>
// CHECK: return %[[p]] : tensor<1x10x10x3xf32>
@@ -98,7 +98,7 @@ func.func @prepareAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func.return %add : tensor<2x2xf32>
// CHECK: %[[cst:.*]] = arith.constant dense<[{{\[}}0.000000e+00, 1.000000e+00], [2.000000e+00, 2.550000e+02]]>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<2x2x!quant.uniform>, volatile}
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<2x2x!quant.uniform>}> {volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[add:.*]] = tfl.add %arg0, %[[dq]]
// CHECK: return %[[add]]
@@ -113,13 +113,13 @@ func.func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32>
func.return %conv : tensor<1x5x5x3xf32>
// CHECK: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<3x3x3x3xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x3x3x3x!quant.uniform:f32:0
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<3x3x3x3x!quant.uniform:f32:0
// CHECK-SAME: {1.000000e+00,1.000000e+00,1.000000e+00}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]]
// PerTensor: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<3x3x3x3xf32>
-// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x3x3x3x!quant.uniform:f32, 1.000000e+00>>, volatile}
+// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<3x3x3x3x!quant.uniform:f32, 1.000000e+00>>}> {volatile}
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]]
}
@@ -133,13 +133,13 @@ func.func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> {
func.return %conv : tensor<1x5x5x3xf32>
// CHECK: %[[cst:.*]] = arith.constant dense<[{{\[\[\[}}0.000000e+00]]], [{{\[\[}}1.270000e+02]]], [{{\[\[}}-1.270000e+02]]]]>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x1x1x1x!quant.uniform:f32:0,
-// CHECK-SAME: {3.9370078740157481E-9,1.000000e+00,1.000000e+00}>>, volatile}
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<3x1x1x1x!quant.uniform:f32:0,
+// CHECK-SAME: {3.9370078740157481E-9,1.000000e+00,1.000000e+00}>>}> {volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]]
// PerTensor: %[[cst:.*]] = arith.constant dense<[{{\[\[\[}}0.000000e+00]]], [{{\[\[}}1.270000e+02]]], [{{\[\[}}-1.270000e+02]]]]>
-// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x1x1x1x!quant.uniform:f32,
+// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<3x1x1x1x!quant.uniform:f32,
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]]
}
@@ -153,13 +153,13 @@ func.func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11
func.return %dc : tensor<1x112x112x32xf32>
// CHECK: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<32x3x3x3xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x3x3x3x!quant.uniform:f32:3
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<32x3x3x3x!quant.uniform:f32:3
// CHECK-SAME: {1.000000e+00,1.000000e+00,1.000000e+00}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq]]
// PerTensor: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<32x3x3x3xf32>
-// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x3x3x3x!quant.uniform:f32,
+// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<32x3x3x3x!quant.uniform:f32,
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// PerTensor: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq]]
}
@@ -173,12 +173,12 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11
func.return %fc : tensor<1x112x112x4xf32>
// CHECK: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<4x12xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<4x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>, volatile}
+// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) <{qtype = tensor<4x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>}> {volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<4x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>) -> tensor<4x12xf32>
// CHECK: "tfl.fully_connected"(%arg0, %[[dq]]
// PerTensor: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<4x12xf32>
-// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<4x12x!quant.uniform:f32, 1.000000e+00>>, volatile}
+// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) <{qtype = tensor<4x12x!quant.uniform:f32, 1.000000e+00>>}> {volatile}
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<4x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<4x12xf32>
// PerTensor: "tfl.fully_connected"(%arg0, %[[dq]]
}
@@ -192,12 +192,12 @@ func.func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4x
func.return %tc : tensor<1x32x42x128xf32>
// CHECK: %[[CST:.*]] = arith.constant dense<1.270000e+02> : tensor<1x32x42x128xf32>
-// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>, volatile}
+// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) <{qtype = tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>}> {volatile}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>) -> tensor<1x32x42x128xf32>
// CHECK: "tfl.transpose_conv"(%arg1, %[[DEQUANTIZE]], %arg0,
// PerTensor: %[[CST:.*]] = arith.constant dense<1.270000e+02> : tensor<1x32x42x128xf32>
-// PerTensor: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>, volatile}
+// PerTensor: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) <{qtype = tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>}> {volatile}
// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
// PerTensor: "tfl.transpose_conv"(%arg1, %[[DEQUANTIZE]], %arg0,
}
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
index 3c83883dfea9a4..eea145a0b9f6ba 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
@@ -60,7 +60,7 @@ func.func @not_reset_input(%arg0: tensor) -> (tensor>} : (tensor) -> tensor>
func.return %0: tensor>
-// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor>}
+// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor>}>
// CHECK-NEXT: return %[[q]]
}
@@ -92,9 +92,9 @@ func.func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
func.return %1 : tensor<8x4x3xf32>
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<8x4x3x!quant.uniform>}> {volatile}
// CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
-// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
+// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) <{qtype = tensor<8x4x3x!quant.uniform>}> {volatile}
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
// CHECK: return %[[dq2]]
}
@@ -106,7 +106,7 @@ func.func @prepareNarrowStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
func.return %0 : tensor<8x4x3xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
+// CHECK: %[[q:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<8x4x3x!quant.uniform>}> {volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: return %[[dq]]
}
@@ -123,7 +123,7 @@ func.func @QuantizeConv2DPerChannel(%arg0: tensor<1x224x224x3x!quant.uniform
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<32xf32>
-// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile}
+// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<32x!quant.uniform>}> {volatile}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
@@ -143,7 +143,7 @@ func.func @QuantizeConv2DPerChannelConst(%arg0: tensor<1x224x224x3x!quant.unifor
func.return %conv : tensor<1x112x112x32xf32>
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<32xf32>
-// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile}
+// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<32x!quant.uniform>}> {volatile}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
@@ -163,7 +163,7 @@ func.func @QuantizeConv2DPerChannels(%arg0: tensor<1x224x224x3x!quant.uniform
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<32xf32>
-// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile}
+// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<32x!quant.uniform>}> {volatile}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
@@ -183,7 +183,7 @@ func.func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform>
// CHECK: %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
-// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile}
+// CHECK: %0 = "tfl.quantize"(%cst) <{qtype = tensor<32x!quant.uniform>}> {volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
@@ -205,7 +205,7 @@ func.func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform>
// CHECK: %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
-// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile}
+// CHECK: %0 = "tfl.quantize"(%cst) <{qtype = tensor<32x!quant.uniform>}> {volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
@@ -227,7 +227,7 @@ func.func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform>
// CHECK: %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
-// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile}
+// CHECK: %0 = "tfl.quantize"(%cst) <{qtype = tensor<32x!quant.uniform>}> {volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
@@ -308,7 +308,7 @@ func.func @QuantizeStridedSlice(tensor<12x2x2x5x!quant.uniform>, te
// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg3)
-// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x2x2x5x!quant.uniform>, volatile}
+// CHECK: %2 = "tfl.quantize"(%1) <{qtype = tensor<1x2x2x5x!quant.uniform>}> {volatile}
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor<1x2x2x5xf32>
}
@@ -353,7 +353,7 @@ func.func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform>)
// CHECK: %1 = "tfl.reshape"(%0, %{{.*}}) : (tensor<1x6x6x16xf32>, tensor<3xi32>) -> tensor<1x36x16xf32>
-// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x36x16x!quant.uniform>, volatile}
+// CHECK: %2 = "tfl.quantize"(%1) <{qtype = tensor<1x36x16x!quant.uniform>}> {volatile}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x36x16x!quant.uniform>)
// CHECK: return %3 : tensor<1x36x16xf32>
}
@@ -366,8 +366,8 @@ func.func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform
// CHECK: %0 = "tfl.dequantize"(%arg0)
-// CHECK: %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
-// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform>, volatile}
+// CHECK: %1 = "tfl.softmax"(%0) <{beta = 1.000000e+00 : f32}> : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
+// CHECK: %2 = "tfl.quantize"(%1) <{qtype = tensor<1x6x6x16x!quant.uniform>}> {volatile}
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor<1x6x6x16xf32>
}
@@ -381,7 +381,7 @@ func.func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform) -> tensor<1x6x6x16xf32>
-// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform>, volatile}
+// CHECK: %2 = "tfl.quantize"(%1) <{qtype = tensor<1x6x6x16x!quant.uniform>}> {volatile}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
// CHECK: return %3 : tensor<1x6x6x16xf32>
}
@@ -416,7 +416,7 @@ func.func @QDQNoQuantizeSoftmax(tensor<1x6x6x16x!quant.uniform
// QDQ: %0 = "tfl.dequantize"(%arg0)
-// QDQ: %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
+// QDQ: %1 = "tfl.softmax"(%0) <{beta = 1.000000e+00 : f32}> : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
// QDQ-NOT: "tfl.quantize"
// QDQ: return %1 : tensor<1x6x6x16xf32>
}
@@ -429,7 +429,7 @@ func.func @QuantizeL2Norm(%arg0: tensor<1x6x6x16x!quant.uniform>) -
// CHECK: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK: %[[l2:.*]] = "tfl.l2_normalization"(%[[in]])
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[l2]]) {qtype = tensor<1x6x6x16x!quant.uniform>, volatile}
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[l2]]) <{qtype = tensor<1x6x6x16x!quant.uniform>}> {volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: return %[[dq]] : tensor<1x6x6x16xf32>
}
@@ -452,11 +452,11 @@ func.func @QuantizeConcatOperand0ToAll(tensor<1x2x!quant.uniform, tensor<1x2xf32>) -> tensor<2x2xf32>
func.return %1 : tensor<2x2xf32>
-// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
+// CHECK: %0 = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
-// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>, volatile}
+// CHECK: %3 = "tfl.concatenation"(%2, %1) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
+// CHECK: %4 = "tfl.quantize"(%3) <{qtype = tensor<2x2x!quant.uniform>}> {volatile}
// CHECK: %5 = "tfl.dequantize"(%4) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32>
// CHECK: return %5 : tensor<2x2xf32>
}
@@ -468,11 +468,11 @@ func.func @QuantizeConcatOperand1ToAll(tensor<1x2xf32>, tensor<1x2x!quant.unifor
%1 = "tfl.concatenation"(%arg0, %0) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
func.return %1 : tensor<2x2xf32>
-// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>, volatile}
+// CHECK: %0 = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> {volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.dequantize"(%arg1) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK: %3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
-// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>, volatile}
+// CHECK: %3 = "tfl.concatenation"(%1, %2) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
+// CHECK: %4 = "tfl.quantize"(%3) <{qtype = tensor<2x2x!quant.uniform>}> {volatile}
// CHECK: %5 = "tfl.dequantize"(%4) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32>
// CHECK: return %5 : tensor<2x2xf32>
}
@@ -484,12 +484,12 @@ func.func @QuantizeConcatResToAll(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
func.return %1 : tensor<2x2x!quant.uniform>
-// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
+// CHECK: %0 = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK: %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>, volatile}
+// CHECK: %2 = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> {volatile}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK: %4 = "tfl.concatenation"(%3, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
-// CHECK: %5 = "tfl.quantize"(%4) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
+// CHECK: %4 = "tfl.concatenation"(%3, %1) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
+// CHECK: %5 = "tfl.quantize"(%4) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
// CHECK: return %5 : tensor<2x2x!quant.uniform>
}
@@ -501,11 +501,11 @@ func.func @QuantizeConcatResToAllNoRequantize(tensor<1x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
func.return %2 : tensor<2x2x!quant.uniform>
-// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
+// CHECK: %0 = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
-// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
+// CHECK: %3 = "tfl.concatenation"(%2, %1) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
+// CHECK: %4 = "tfl.quantize"(%3) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
// CHECK: return %4 : tensor<2x2x!quant.uniform>
}
@@ -518,13 +518,13 @@ func.func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) ->
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
func.return %3 : tensor<2x2x!quant.uniform>
-// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
+// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile}
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
-// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
+// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
+// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
-// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
+// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
+// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform>
}
@@ -536,12 +536,12 @@ func.func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
func.return %3 : tensor<2x2x!quant.uniform>
-// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
+// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile}
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
+// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
-// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
+// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
+// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform>
}
@@ -551,8 +551,8 @@ func.func @NotRequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.un
%10 = "tfl.concatenation"(%arg0, %9) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform>, tensor<1x73x73x96x!quant.uniform>) -> tensor<1x73x73x160x!quant.uniform>
func.return %10 : tensor<1x73x73x160x!quant.uniform>
-// CHECK: %[[max:.*]] = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform>) -> tensor<1x73x73x96x!quant.uniform>
-// CHECK: %[[cat:.*]] = "tfl.concatenation"(%arg0, %[[max]]) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform>, tensor<1x73x73x96x!quant.uniform>) -> tensor<1x73x73x160x!quant.uniform>
+// CHECK: %[[max:.*]] = "tfl.max_pool_2d"(%arg1) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x147x147x96x!quant.uniform>) -> tensor<1x73x73x96x!quant.uniform>
+// CHECK: %[[cat:.*]] = "tfl.concatenation"(%arg0, %[[max]]) <{axis = 3 : i32, fused_activation_function = "NONE"}> : (tensor<1x73x73x64x!quant.uniform>, tensor<1x73x73x96x!quant.uniform>) -> tensor<1x73x73x160x!quant.uniform>
// CHECK: return %[[cat]] : tensor<1x73x73x160x!quant.uniform>
}
@@ -577,22 +577,22 @@ func.func @QuantizeChain(tensor<1x224x224x3x!quant.uniform
// CHECK: %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
-// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile}
+// CHECK: %0 = "tfl.quantize"(%cst) <{qtype = tensor<32x!quant.uniform>}> {volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>)
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>)
// CHECK: %3 = "tfl.pseudo_qconst"()
// CHECK: %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>)
// CHECK: %5 = "tfl.average_pool_2d"(%2)
-// CHECK: %6 = "tfl.quantize"(%5) {qtype = tensor<1x224x224x3x!quant.uniform>, volatile}
+// CHECK: %6 = "tfl.quantize"(%5) <{qtype = tensor<1x224x224x3x!quant.uniform>}> {volatile}
// CHECK: %7 = "tfl.dequantize"(%6) : (tensor<1x224x224x3x!quant.uniform>)
// CHECK: %8 = "tfl.conv_2d"(%7, %4, %1)
-// CHECK: %9 = "tfl.quantize"(%8) {qtype = tensor<1x112x112x32x!quant.uniform>}
+// CHECK: %9 = "tfl.quantize"(%8) <{qtype = tensor<1x112x112x32x!quant.uniform>}>
// CHECK: %10 = "tfl.dequantize"(%9) : (tensor<1x112x112x32x!quant.uniform>)
// CHECK: %11 = "tfl.reshape"(%10, %{{.*}})
-// CHECK: %12 = "tfl.quantize"(%11) {qtype = tensor<1x36x16x!quant.uniform>, volatile}
+// CHECK: %12 = "tfl.quantize"(%11) <{qtype = tensor<1x36x16x!quant.uniform>}> {volatile}
// CHECK: %13 = "tfl.dequantize"(%12) : (tensor<1x36x16x!quant.uniform>)
// CHECK: %14 = "tfl.softmax"(%13)
-// CHECK: %15 = "tfl.quantize"(%14) {qtype = tensor<1x36x16x!quant.uniform>, volatile}
+// CHECK: %15 = "tfl.quantize"(%14) <{qtype = tensor<1x36x16x!quant.uniform>}> {volatile}
// CHECK: %16 = "tfl.dequantize"(%15) : (tensor<1x36x16x!quant.uniform>)
// CHECK: return %16 : tensor<1x36x16xf32>
}
@@ -603,7 +603,7 @@ func.func @QuantizeConstant() -> tensor<2x3xf32> {
func.return %cst : tensor<2x3xf32>
// CHECK: %cst = arith.constant dense{{.*}}tensor<2x3xf32>
-// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform>, volatile}
+// CHECK: %0 = "tfl.quantize"(%cst) <{qtype = tensor<2x3x!quant.uniform>}> {volatile}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<2x3xf32>
}
@@ -613,7 +613,7 @@ func.func @NotQuantizeNoneType() -> none {
%cst = "tfl.no_value"() {value = unit} : () -> none
func.return %cst : none
-// CHECK-NEXT: %[[cst:.*]] = "tfl.no_value"() {value} : () -> none
+// CHECK-NEXT: %[[cst:.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK-NEXT: return %[[cst]]
}
@@ -623,7 +623,7 @@ func.func @QuantizeZeroSplat() -> tensor<2x3xf32> {
func.return %cst : tensor<2x3xf32>
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor<2x3xf32>
-// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>, volatile}
+// CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor<2x3x!quant.uniform>}> {volatile}
}
// CHECK-LABEL: QuantizeZeroScalar
@@ -632,7 +632,7 @@ func.func @QuantizeZeroScalar() -> tensor {
func.return %cst : tensor
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor
-// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>, volatile}
+// CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor>}> {volatile}
}
// CHECK-LABEL: QuantizePositiveSplat
@@ -641,7 +641,7 @@ func.func @QuantizePositiveSplat() -> tensor<2x3xf32> {
func.return %cst : tensor<2x3xf32>
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<2.540000e+01> : tensor<2x3xf32>
-// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>, volatile}
+// CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor<2x3x!quant.uniform>}> {volatile}
}
// CHECK-LABEL: QuantizePositiveScalar
@@ -650,7 +650,7 @@ func.func @QuantizePositiveScalar() -> tensor {
func.return %cst : tensor
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<2.540000e+00> : tensor
-// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>, volatile}
+// CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor>}> {volatile}
}
// CHECK-LABEL: QuantizeNegativeSplat
@@ -659,7 +659,7 @@ func.func @QuantizeNegativeSplat() -> tensor<2x3xf32> {
func.return %cst : tensor<2x3xf32>
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<-2.540000e+00> : tensor<2x3xf32>
-// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>, volatile}
+// CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor<2x3x!quant.uniform>}> {volatile}
}
// CHECK-LABEL: QuantizeNegativeScalar
@@ -668,7 +668,7 @@ func.func @QuantizeNegativeScalar() -> tensor {
func.return %cst : tensor
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<-2.540000e+01> : tensor
-// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>, volatile}
+// CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor>}> {volatile}
}
// Make sure biases are not shared.
@@ -721,7 +721,7 @@ func.func @QuantizeSharedBiases2(
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf32>
-// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform>, volatile}
+// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) <{qtype = tensor<32x!quant.uniform>}> {volatile}
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
@@ -746,7 +746,7 @@ func.func @QuantizeSharedBiases3(
func.return %3, %7 : tensor<32x!quant.uniform>, tensor<1x56x56x32x!quant.uniform>
// CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile}
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) <{qtype = tensor<32x!quant.uniform>}> {volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]])
@@ -796,10 +796,10 @@ func.func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32
func.return %c : tensor<1x112x112x32xf32>
// CHECK: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<32x3x3x3xf32>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.003937007874015748:1>>, volatile} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3x!quant.uniform:f32, 0.003937007874015748:1>>
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.003937007874015748:1>>}> {volatile} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3x!quant.uniform:f32, 0.003937007874015748:1>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<32x3x3x3x!quant.uniform:f32, 0.003937007874015748:1>>) -> tensor<32x3x3x3xf32>
// CHECK: %[[b:.*]] = arith.constant dense<-1.000000e+00> : tensor<32xf32>
-// CHECK: %[[c:.*]] = "tfl.conv_2d"(%arg0, %[[dq]], %[[b]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
+// CHECK: %[[c:.*]] = "tfl.conv_2d"(%arg0, %[[dq]], %[[b]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
// CHECK: return %[[c]] : tensor<1x112x112x32xf32>
}
@@ -812,7 +812,7 @@ func.func @NoRedundantQuantizeWeight() -> tensor<1x112x112x32xf32> {
func.return %dq : tensor<1x112x112x32xf32>
// CHECK-NEXT: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<1x112x112x32xf32>
-// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<1x112x112x32x!quant.uniform>}
+// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<1x112x112x32x!quant.uniform>}>
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: return %[[dq]] : tensor<1x112x112x32xf32>
}
@@ -849,34 +849,34 @@ func.func @QuantizedCatsAddRequantsTest(%arg0: tensor<1x1xf32>, %arg1: tensor<1x
%13 = "tfl.concatenation"(%9, %3) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x3xf32>
%14 = "quantfork.stats"(%13) {layerStats = dense<[-0.488159984, 0.398609281]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
func.return %10, %14 : tensor<1x4xf32>, tensor<1x3xf32>
-// CHECK-NEXT: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
-// CHECK-NEXT: %[[r0q0:.*]] = "tfl.quantize"(%[[q0]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
-// CHECK-NEXT: %[[r1q0:.*]] = "tfl.quantize"(%[[q0]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[r0q0:.*]] = "tfl.quantize"(%[[q0]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[r1q0:.*]] = "tfl.quantize"(%[[q0]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[d1q0:.*]] = "tfl.dequantize"(%[[r1q0]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: %[[d0q0:.*]] = "tfl.dequantize"(%[[r0q0]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
-// CHECK-NEXT: %[[r0q1:.*]] = "tfl.quantize"(%[[q1]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[r0q1:.*]] = "tfl.quantize"(%[[q1]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[d0q1:.*]] = "tfl.dequantize"(%[[r0q1]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%arg2) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
-// CHECK-NEXT: %[[r0q2:.*]] = "tfl.quantize"(%[[q2]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%arg2) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[r0q2:.*]] = "tfl.quantize"(%[[q2]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[d0q2:.*]] = "tfl.dequantize"(%[[r0q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg3) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
-// CHECK-NEXT: %[[r0q3:.*]] = "tfl.quantize"(%[[q3]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg3) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
+// CHECK-NEXT: %[[r0q3:.*]] = "tfl.quantize"(%[[q3]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[d0q3:.*]] = "tfl.dequantize"(%[[r0q3]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
-// CHECK-NEXT: %[[cat1_0:.*]] = "tfl.concatenation"(%[[d0q1]], %[[d1q0]]) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>
-// CHECK-NEXT: %[[qcat1_0:.*]] = "tfl.quantize"(%[[cat1_0]]) {qtype = tensor<1x2x!quant.uniform>, volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
-// CHECK-NEXT: %[[r0qcat1_0:.*]] = "tfl.quantize"(%[[qcat1_0]]) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
+// CHECK-NEXT: %[[cat1_0:.*]] = "tfl.concatenation"(%[[d0q1]], %[[d1q0]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>
+// CHECK-NEXT: %[[qcat1_0:.*]] = "tfl.quantize"(%[[cat1_0]]) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
+// CHECK-NEXT: %[[r0qcat1_0:.*]] = "tfl.quantize"(%[[qcat1_0]]) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
// CHECK-NEXT: %[[d0qcat1_0:.*]] = "tfl.dequantize"(%[[r0qcat1_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK-NEXT: %[[cat_2_0:.*]] = "tfl.concatenation"(%[[d0q2]], %[[d0q0]]) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>
-// CHECK-NEXT: %[[qcat_2_0:.*]] = "tfl.quantize"(%[[cat_2_0]]) {qtype = tensor<1x2x!quant.uniform>, volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
-// CHECK-NEXT: %[[r0qcat_2_0:.*]] = "tfl.quantize"(%[[qcat_2_0]]) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
+// CHECK-NEXT: %[[cat_2_0:.*]] = "tfl.concatenation"(%[[d0q2]], %[[d0q0]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>
+// CHECK-NEXT: %[[qcat_2_0:.*]] = "tfl.quantize"(%[[cat_2_0]]) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
+// CHECK-NEXT: %[[r0qcat_2_0:.*]] = "tfl.quantize"(%[[qcat_2_0]]) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
// CHECK-NEXT: %[[d0qcat_2_0:.*]] = "tfl.dequantize"(%[[r0qcat_2_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK-NEXT: %[[dqcat_2_0:.*]] = "tfl.dequantize"(%[[qcat_2_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
-// CHECK-NEXT: %[[cat_2_0_1_0:.*]] = "tfl.concatenation"(%[[dqcat_2_0]], %[[d0qcat1_0]]) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x4xf32>
-// CHECK-NEXT: %[[qcat_2_0_1_0:.*]] = "tfl.quantize"(%[[cat_2_0_1_0]]) {qtype = tensor<1x4x!quant.uniform>, volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform>
+// CHECK-NEXT: %[[cat_2_0_1_0:.*]] = "tfl.concatenation"(%[[dqcat_2_0]], %[[d0qcat1_0]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x4xf32>
+// CHECK-NEXT: %[[qcat_2_0_1_0:.*]] = "tfl.quantize"(%[[cat_2_0_1_0]]) <{qtype = tensor<1x4x!quant.uniform>}> {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform>
// CHECK-NEXT: %[[dqcat_2_0_1_0:.*]] = "tfl.dequantize"(%[[qcat_2_0_1_0]]) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32>
-// CHECK-NEXT: %[[cat_2_0_3:.*]] = "tfl.concatenation"(%[[d0qcat_2_0]], %[[d0q3]]) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x3xf32>
-// CHECK-NEXT: %[[qcat_2_0_3:.*]] = "tfl.quantize"(%[[cat_2_0_3]]) {qtype = tensor<1x3x!quant.uniform>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform>
+// CHECK-NEXT: %[[cat_2_0_3:.*]] = "tfl.concatenation"(%[[d0qcat_2_0]], %[[d0q3]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x3xf32>
+// CHECK-NEXT: %[[qcat_2_0_3:.*]] = "tfl.quantize"(%[[cat_2_0_3]]) <{qtype = tensor<1x3x!quant.uniform>}> {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform>
// CHECK-NEXT: %[[dqcat_2_0_3:.*]] = "tfl.dequantize"(%[[qcat_2_0_3]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32>
// CHECK-NEXT: return %[[dqcat_2_0_1_0]], %[[dqcat_2_0_3]] : tensor<1x4xf32>, tensor<1x3xf32>
}
@@ -892,10 +892,10 @@ func.func @TransposePerTensorQuantizationPropagation() -> tensor<2x5xf32> {
// QDQ: %[[perm:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// QDQ-NEXT: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<5x2xf32>
- // QDQ-NEXT: %[[qw:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<5x2x!quant.uniform:f32
+ // QDQ-NEXT: %[[qw:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<5x2x!quant.uniform:f32
// QDQ-NEXT: %[[dqw:.*]] = "tfl.dequantize"(%[[qw]]) : (tensor<5x2x!quant.uniform:f32
// QDQ-NEXT: %[[tp:.*]] = "tfl.transpose"(%[[dqw]], %[[perm]]) : (tensor<5x2xf32>, tensor<2xi32>) -> tensor<2x5xf32>
- // QDQ-NEXT: %[[qtw:.*]] = "tfl.quantize"(%[[tp]]) {qtype = tensor<2x5x!quant.uniform:f32
+ // QDQ-NEXT: %[[qtw:.*]] = "tfl.quantize"(%[[tp]]) <{qtype = tensor<2x5x!quant.uniform:f32
// QDQ-NEXT: %[[dqtw:.*]] = "tfl.dequantize"(%[[qtw]]) : (tensor<2x5x!quant.uniform:f32
// QDQ-NEXT: return %[[dqtw]] : tensor<2x5xf32>
}
@@ -911,10 +911,10 @@ func.func @TransposePerChannelNewQuantDim() -> tensor<2x5xf32> {
// QDQ: %[[perm:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// QDQ-NEXT: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<5x2xf32>
-// QDQ-NEXT: %[[qw:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<5x2x!quant.uniform:f32:0
+// QDQ-NEXT: %[[qw:.*]] = "tfl.quantize"(%[[w]]) <{qtype = tensor<5x2x!quant.uniform:f32:0
// QDQ-NEXT: %[[dqw:.*]] = "tfl.dequantize"(%[[qw]]) : (tensor<5x2x!quant.uniform:f32:0
// QDQ-NEXT: %[[tp:.*]] = "tfl.transpose"(%[[dqw]], %[[perm]]) : (tensor<5x2xf32>, tensor<2xi32>) -> tensor<2x5xf32>
-// QDQ-NEXT: %[[qtw:.*]] = "tfl.quantize"(%[[tp]]) {qtype = tensor<2x5x!quant.uniform:f32:1
+// QDQ-NEXT: %[[qtw:.*]] = "tfl.quantize"(%[[tp]]) <{qtype = tensor<2x5x!quant.uniform:f32:1
// QDQ-NEXT: %[[dqtw:.*]] = "tfl.dequantize"(%[[qtw]]) : (tensor<2x5x!quant.uniform:f32:1
// QDQ-NEXT: return %[[dqtw]] : tensor<2x5xf32>
}
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant-4bit.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant-4bit.mlir
index dca4c21766ee4a..d9e15db9a182a5 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant-4bit.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant-4bit.mlir
@@ -12,7 +12,7 @@ func.func @fakeQuantPerChannelForActivation(%arg0: tensor<8x4xf32>) -> (tensor<8
// CHECK: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0)
// The last channel tests the code in quantization utils that expands very small ranges to be at least 1e-6.
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[fq]]) {qtype = tensor<8x4x!quant.uniform>} +// CHECK: %[[q:.*]] = "tfl.quantize"(%[[fq]]) <{qtype = tensor<8x4x!quant.uniform>}> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // CHECK: return %[[dq]] } @@ -26,7 +26,7 @@ func.func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) { func.return %0 : tensor<8xf32> // CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) -// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform>} +// CHECK: %1 = "tfl.quantize"(%0) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %2 = "tfl.dequantize"(%1) // CHECK: return %2 } @@ -41,7 +41,7 @@ func.func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quan func.return %1 : tensor<8x!quant.uniform> // CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) <{narrow_range = false, num_bits = 3 : i64}> -// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform>} +// CHECK: %1 = "tfl.quantize"(%0) <{qtype = tensor<8x!quant.uniform>}> // CHECK: return %1 } @@ -60,7 +60,7 @@ func.func @WrappedFakeQuantFolded() -> tensor<8xf32> { func.return %rst : tensor<8xf32> // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } @@ -76,7 +76,7 @@ func.func @fakeQuantFolded() -> (tensor<8xf32>) { func.return %rst : tensor<8xf32> // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } @@ -90,7 +90,7 @@ func.func @fakeQuantFoldedWithoutIdentity() -> (tensor<8xf32>) { func.return %rst : tensor<8xf32> // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } @@ -108,7 +108,7 @@ func.func @fakeQuantFoldedWithCast() -> (tensor<8xf32>) { func.return %rst : tensor<8xf32> // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } @@ -198,7 +198,7 @@ func.func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x8x7x16xf3 // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<16x3x3x3xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) <{qtype = tensor<16x3x3x3x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = 
"tfl.dequantize"(%[[QUANTIZE]]) // CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) // CHECK: return %[[CONV]] @@ -218,7 +218,7 @@ func.func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256 // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<16x3x3x3xf32> // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) @@ -239,7 +239,7 @@ func.func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<48xf32> // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<1x3x3x48xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) <{qtype = tensor<1x3x3x48x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: %[[CONV:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) // CHECK: return %[[CONV]] @@ -259,7 +259,7 @@ func.func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (t // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<1x3x3x48xf32> // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<48xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform>} @@ -285,7 +285,7 @@ func.func @perChannelFakeQuantWithDepthwiseConv2DWithReshape(%arg: tensor<1x160x // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<1x3x3x48xf32> // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<48xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform>} @@ -302,7 +302,7 @@ func.func @fakeQuant3BitPerChannelForActivation(%arg0: tensor<8x4xf32>) -> (tens func.return %0 : tensor<8x4xf32> // LOBIT: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0) -// LOBIT: %[[q:.*]] = "tfl.quantize"(%[[fq]]) {qtype = tensor<8x4x!quant.uniform:f32:1, {1.000000e+00,1.000000e+00:1,2.000000e+00:4,2.000000e+00:3}>>} +// LOBIT: %[[q:.*]] = "tfl.quantize"(%[[fq]]) <{qtype = tensor<8x4x!quant.uniform:f32:1, {1.000000e+00,1.000000e+00:1,2.000000e+00:4,2.000000e+00:3}>>}> // LOBIT: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // LOBIT: return %[[dq]] } @@ -316,7 +316,7 @@ func.func @fakeQuant3BitForActivation(tensor<8xf32>) -> (tensor<8xf32>) { func.return %0 : tensor<8xf32> // LOBIT: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) -// LOBIT: %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform:f32, 2.000000e+00:3>>} +// LOBIT: %1 = "tfl.quantize"(%0) <{qtype = tensor<8x!quant.uniform:f32, 2.000000e+00:3>>}> // LOBIT: %2 = "tfl.dequantize"(%1) // LOBIT: return %2 } @@ -335,7 +335,7 @@ func.func @fakeQuant4BitWithConv2DPerChannel(tensor<256x32x32x3xf32>) -> (tensor // LOBIT-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<4xf32> // LOBIT-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<4x3x3x3xf32> -// LOBIT: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<4x3x3x3x!quant.uniform:f32:0, 
{1.000000e+00:1,1.000000e+00:2,1.000000e+00:7,1.000000e+00:15}>>} +// LOBIT: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) <{qtype = tensor<4x3x3x3x!quant.uniform:f32:0, {1.000000e+00:1,1.000000e+00:2,1.000000e+00:7,1.000000e+00:15}>>}> // LOBIT: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // LOBIT: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) // LOBIT: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir index c65cecc188f468..b9fe9310588d77 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir @@ -11,7 +11,7 @@ func.func @fakeQuantPerChannelForActivation(%arg0: tensor<8x4xf32>) -> (tensor<8 func.return %0 : tensor<8x4xf32> // CHECK: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0) -// CHECK: %[[q:.*]] = "tfl.quantize"(%[[fq]]) {qtype = tensor<8x4x!quant.uniform>} +// CHECK: %[[q:.*]] = "tfl.quantize"(%[[fq]]) <{qtype = tensor<8x4x!quant.uniform>}> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // CHECK: return %[[dq]] } @@ -25,7 +25,7 @@ func.func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) { func.return %0 : tensor<8xf32> // CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) -// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform>} +// CHECK: %1 = "tfl.quantize"(%0) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %2 = "tfl.dequantize"(%1) // CHECK: return %2 } @@ -40,7 +40,7 @@ func.func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quan func.return %1 : tensor<8x!quant.uniform> // CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) <{narrow_range = false, num_bits = 5 : i64}> -// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform>} +// CHECK: %1 = "tfl.quantize"(%0) <{qtype = tensor<8x!quant.uniform>}> // CHECK: return %1 } @@ -59,7 +59,7 @@ func.func @WrappedFakeQuantFolded() -> tensor<8xf32> { func.return %rst : tensor<8xf32> // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } @@ -75,7 +75,7 @@ func.func @fakeQuantFolded() -> (tensor<8xf32>) { func.return %rst : tensor<8xf32> // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } @@ -89,7 +89,7 @@ func.func @fakeQuantFoldedWithoutIdentity() -> (tensor<8xf32>) { func.return %rst : tensor<8xf32> // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } @@ -107,7 +107,7 @@ func.func @fakeQuantFoldedWithCast() -> (tensor<8xf32>) { 
func.return %rst : tensor<8xf32> // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) <{qtype = tensor<8x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } @@ -197,7 +197,7 @@ func.func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x8x7x16xf3 // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<16x3x3x3xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) <{qtype = tensor<16x3x3x3x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) // CHECK: return %[[CONV]] @@ -217,7 +217,7 @@ func.func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256 // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<16x3x3x3xf32> // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) @@ -238,7 +238,7 @@ func.func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<48xf32> // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<1x3x3x48xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform>} +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) <{qtype = tensor<1x3x3x48x!quant.uniform>}> // CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) // CHECK: %[[CONV:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]]) // CHECK: return %[[CONV]] @@ -258,7 +258,7 @@ func.func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (t // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<1x3x3x48xf32> // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<48xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform>} @@ -284,7 +284,7 @@ func.func @perChannelFakeQuantWithDepthwiseConv2DWithReshape(%arg: tensor<1x160x // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<0.000000e+00> : tensor<1x3x3x48xf32> // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<48xf32> -// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform>} @@ -301,7 +301,7 @@ func.func @fakeQuant3BitPerChannelForActivation(%arg0: tensor<8x4xf32>) -> (tens func.return %0 : tensor<8x4xf32> // LOBIT: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0) -// LOBIT: %[[q:.*]] = "tfl.quantize"(%[[fq]]) {qtype = tensor<8x4x!quant.uniform:f32:1, {1.000000e+00,1.000000e+00:1,2.000000e+00:16,2.000000e+00:15}>>} +// LOBIT: %[[q:.*]] = "tfl.quantize"(%[[fq]]) <{qtype = tensor<8x4x!quant.uniform:f32:1, {1.000000e+00,1.000000e+00:1,2.000000e+00:16,2.000000e+00:15}>>}> // 
LOBIT: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // LOBIT: return %[[dq]] } @@ -315,7 +315,7 @@ func.func @fakeQuant3BitForActivation(tensor<8xf32>) -> (tensor<8xf32>) { func.return %0 : tensor<8xf32> // LOBIT: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) -// LOBIT: %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform:f32, 2.000000e+00:15>>} +// LOBIT: %1 = "tfl.quantize"(%0) <{qtype = tensor<8x!quant.uniform:f32, 2.000000e+00:15>>}> // LOBIT: %2 = "tfl.dequantize"(%1) // LOBIT: return %2 } diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 1b323ff3df689b..785cfa2fa2d26f 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -23,13 +23,13 @@ func.func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<256x3x32x3 // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> // CHECK-DAG: %[[CONSTANT1:.*]] = arith.constant dense<[{{\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> // CHECK: %0 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32> -// CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %[[CONSTANT]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x8x7x16xf32> +// CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %[[CONSTANT]]) <{dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x8x7x16xf32> // CHECK: %2 = "tf.Conv2D" // CHECK: %3 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32> -// CHECK: %4 = "tfl.conv_2d"(%arg0, %3, %[[CONSTANT]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x8x6x16xf32> +// CHECK: %4 = "tfl.conv_2d"(%arg0, %3, %[[CONSTANT]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x8x6x16xf32> // CHECK: %5 = "tf.Pad"(%arg0, %[[CONSTANT1]]) : (tensor<256x32x32x3xf32>, tensor<4x2xi32>) -> tensor<*xf32> // CHECK: %6 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32> -// CHECK: %7 = "tfl.conv_2d"(%5, %6, %[[CONSTANT]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<*xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> +// CHECK: %7 = "tfl.conv_2d"(%5, %6, %[[CONSTANT]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<*xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> // CHECK: %8 = "tf.Conv2D"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [2, 1, 1, 1]}> {T = "tfdtype$DT_FLOAT"} : (tensor<256x32x32x3xf32>, 
tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> } @@ -50,10 +50,10 @@ func.func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>, tensor< // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<12xf32> // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<[1, 3, 3, 12]> : tensor<4xi32> // CHECK: %0 = "tf.Reshape"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x4xf32>, tensor<4xi32>) -> tensor<1x3x3x12xf32> -// CHECK: %1 = "tfl.depthwise_conv_2d"(%arg0, %0, %[[CONSTANT]]) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32> +// CHECK: %1 = "tfl.depthwise_conv_2d"(%arg0, %0, %[[CONSTANT]]) <{depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32> // CHECK: %2 = "tf.DepthwiseConv2dNative" // CHECK: %3 = "tf.Reshape"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x4xf32>, tensor<4xi32>) -> tensor<1x3x3x12xf32> -// CHECK: %4 = "tfl.depthwise_conv_2d"(%arg0, %3, %[[CONSTANT]]) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32> +// CHECK: %4 = "tfl.depthwise_conv_2d"(%arg0, %3, %[[CONSTANT]]) <{depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}> : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32> // CHECK: %5 = "tf.DepthwiseConv2dNative" } @@ -145,7 +145,7 @@ func.func @QDQsFollowedByTranspose(tensor<1x2xf32>) -> (tensor<2x1xf32>) { // CHECK: %cst = arith.constant // CHECK: %[[trans:.*]] = "tf.Transpose" // CHECK-SAME: -> tensor<2x1xf32> -// CHECK: %[[q:.*]] = "tfl.quantize"(%[[trans]]) {qtype = tensor<2x1x!quant.uniform>} +// CHECK: %[[q:.*]] = "tfl.quantize"(%[[trans]]) <{qtype = tensor<2x1x!quant.uniform>}> // CHECK-SAME: -> tensor<2x1x!quant.uniform> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // CHECK-SAME: -> tensor<2x1xf32> @@ -164,7 +164,7 @@ func.func @QDQFollowedByReshape(tensor<1x2xf32>) -> (tensor<2x1xf32>) { // CHECK: %cst = arith.constant // CHECK: %[[rs:.*]] = "tf.Reshape" // CHECK-SAME: -> tensor<2x1xf32> -// CHECK: %[[q:.*]] = "tfl.quantize"(%[[rs]]) {qtype = tensor<2x1x!quant.uniform>} +// CHECK: %[[q:.*]] = "tfl.quantize"(%[[rs]]) <{qtype = tensor<2x1x!quant.uniform>}> // CHECK-SAME: -> tensor<2x1x!quant.uniform> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // CHECK-SAME: -> tensor<2x1xf32> @@ -503,7 +503,7 @@ func.func @xla_conv_v2(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> { func.return %4 : tensor<4x8x8x16xf32> // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<1.000000e+00> : tensor<16x3x3x16xf32> - // CHECK: %[[RES:.*]] = "tfl.conv_2d"(%arg0, %[[CST0]], %[[CST]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<4x8x8x16xf32>, 
tensor<16x3x3x16xf32>, tensor<16xf32>) -> tensor<4x8x8x16xf32> + // CHECK: %[[RES:.*]] = "tfl.conv_2d"(%arg0, %[[CST0]], %[[CST]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<4x8x8x16xf32>, tensor<16x3x3x16xf32>, tensor<16xf32>) -> tensor<4x8x8x16xf32> // CHECK: return %[[RES]] } @@ -661,7 +661,7 @@ func.func @QuantDequantTranspose(%arg0: tensor<2x3xf32>) -> (tensor<2x4xf32>) { // CHECK-LABEL: QuantDequantTranspose // CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<1.00392163> : tensor<3x4xf32> - // CHECK: %[[QUANT:.*]] = "tfl.quantize"(%[[CST_0]]) {qtype = tensor<3x4x!quant.uniform>} : (tensor<3x4xf32>) -> tensor<3x4x!quant.uniform> + // CHECK: %[[QUANT:.*]] = "tfl.quantize"(%[[CST_0]]) <{qtype = tensor<3x4x!quant.uniform>}> : (tensor<3x4xf32>) -> tensor<3x4x!quant.uniform> // CHECK: %[[DEQUANT:.*]] = "tfl.dequantize"(%[[QUANT]]) : (tensor<3x4x!quant.uniform>) -> tensor<3x4xf32> // CHECK: %[[TRANSPOSE:.*]] = "tf.Transpose"(%[[DEQUANT]], %[[CST]]) : (tensor<3x4xf32>, tensor) -> tensor<*xf32> // CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[TRANSPOSE]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = true}> : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<2x4xf32> @@ -675,7 +675,7 @@ func.func @GroupConv(%arg0: tensor, %arg1: tensor<1x3x2x14xf32>) // CHECK-DAG: %[[CONSTANT:.*]] = arith.constant dense<0.000000e+00> : tensor<14xf32> // CHECK-DAG: %[[CONSTANT0:.*]] = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> // CHECK: %0 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<1x3x2x14xf32>, tensor<4xi32>) -> tensor<14x1x3x2xf32> - // CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %[[CONSTANT]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 5 : i32} : (tensor, tensor<14x1x3x2xf32>, tensor<14xf32>) -> tensor + // CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %[[CONSTANT]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 5 : i32}> : (tensor, tensor<14x1x3x2xf32>, tensor<14xf32>) -> tensor } func.func @UnsupportedGroupConv_UnrankedTensorType(%arg0: tensor<*xf32>, %arg1: tensor<1x3x2x14xf32>) -> (tensor) { @@ -704,4 +704,80 @@ func.func @RedundantShapeOp(%shape: tensor, %fill: tensor) -> (tenso // CHECK-LABEL: RedundantShapeOp // CHECK-NOT: "tf.Shape" } + +// CHECK-LABEL: @MoveTransposeAcrossPerChannelQuant +func.func @MoveTransposeAcrossPerChannelQuant(%arg0 : tensor<1x224x224x3xf32>) -> tensor<1x112x112x6xf32> { + %cst = "tf.Const"() <{value = dense<6.0> : tensor<6x3x7x7xf32>}> : () -> tensor<6x3x7x7xf32> + %cst_14 = "tf.Const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi64>}> : () -> tensor<4xi64> + %126 = "tfl.quantize"(%cst) {qtype = tensor<6x3x7x7x!quant.uniform:f32:0, {1.412750e-03,3.503970e-04,2.441410e-04,3.823330e-04,2.441410e-04,8.950800e-04}>>} : (tensor<6x3x7x7xf32>) -> tensor<6x3x7x7x!quant.uniform:f32:0, {1.412750e-03,3.503970e-04,2.441410e-04,3.823330e-04,2.441410e-04,8.950800e-04}>> + %127 = "tfl.dequantize"(%126) : (tensor<6x3x7x7x!quant.uniform:f32:0, {1.412750e-03,3.503970e-04,2.441410e-04,3.823330e-04,2.441410e-04,8.950800e-04}>>) -> tensor<6x3x7x7xf32> + %129 = "tf.Transpose"(%127, %cst_14) : (tensor<6x3x7x7xf32>, tensor<4xi64>) -> tensor<7x7x3x6xf32> + %130 = 
"tf.Conv2D"(%arg0, %129) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 3, 3, 3, 3, 0, 0], padding = "EXPLICIT", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x224x224x3xf32>, tensor<7x7x3x6xf32>) -> tensor<1x112x112x6xf32> + return %130 : tensor<1x112x112x6xf32> + // CHECK: %cst = arith.constant dense<6.000000e+00> : tensor<6x7x7x3xf32> + // CHECK: %cst_0 = arith.constant dense<0.000000e+00> : tensor<6xf32> + // CHECK: %cst_1 = arith.constant dense<{{\[\[}}0, 0], [3, 3], [3, 3], [0, 0]]> : tensor<4x2xi32> + // CHECK: %0 = "tf.Pad"(%arg0, %cst_1) : (tensor<1x224x224x3xf32>, tensor<4x2xi32>) -> tensor<*xf32> + // CHECK: %1 = "tfl.quantize"(%cst) <{qtype = tensor<6x7x7x3x!quant.uniform:f32:0, {1.412750e-03,3.503970e-04,2.441410e-04,3.823330e-04,2.441410e-04,8.950800e-04}>>}> : (tensor<6x7x7x3xf32>) -> tensor<6x7x7x3x!quant.uniform:f32:0, {1.412750e-03,3.503970e-04,2.441410e-04,3.823330e-04,2.441410e-04,8.950800e-04}>> + // CHECK: %2 = "tfl.dequantize"(%1) : (tensor<6x7x7x3x!quant.uniform:f32:0, {1.412750e-03,3.503970e-04,2.441410e-04,3.823330e-04,2.441410e-04,8.950800e-04}>>) -> tensor<6x7x7x3xf32> + // CHECK: %3 = "tfl.conv_2d"(%0, %2, %cst_0) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<*xf32>, tensor<6x7x7x3xf32>, tensor<6xf32>) -> tensor<1x112x112x6xf32> + // CHECK: return %3 : tensor<1x112x112x6xf32> +} + +// CHECK-LABEL: @FoldDoubleTranspose +func.func @FoldDoubleTranspose(%arg0: tensor<1x4x1440x256xf32>) -> tensor<1x1440x256x4xf32> { + %cst_12 = arith.constant dense<[0, 1, 3, 2]> : tensor<4xi32> + %cst_18 = arith.constant dense<[0, 2, 1, 3]> : tensor<4xi32> + %2112 = "tf.Transpose"(%arg0, %cst_18) : (tensor<1x4x1440x256xf32>, tensor<4xi32>) -> tensor<1x1440x4x256xf32> + %2114 = "tf.Transpose"(%2112, %cst_12) : (tensor<1x1440x4x256xf32>, tensor<4xi32>) -> tensor<1x1440x256x4xf32> + return %2114 : tensor<1x1440x256x4xf32> + // CHECK-DAG: %cst = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> + // CHECK: %0 = "tf.Transpose"(%arg0, %cst) : (tensor<1x4x1440x256xf32>, tensor<4xi32>) -> tensor<1x1440x256x4xf32> + // CHECK: return %0 +} + +// CHECK-LABEL: @FoldMultpleTranspose +func.func @FoldMultpleTranspose(%arg0: tensor<1x4x1440x256xf32>) -> tensor<1x256x4x1440xf32> { + %cst_11 = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> + %cst_12 = arith.constant dense<[0, 1, 3, 2]> : tensor<4xi32> + %cst_18 = arith.constant dense<[0, 2, 1, 3]> : tensor<4xi32> + %2112 = "tf.Transpose"(%arg0, %cst_11) : (tensor<1x4x1440x256xf32>, tensor<4xi32>) -> tensor<1x1440x256x4xf32> + %2113 = "tf.Transpose"(%2112, %cst_18) : (tensor<1x1440x256x4xf32>, tensor<4xi32>) -> tensor<1x256x1440x4xf32> + %2114 = "tf.Transpose"(%2113, %cst_12) : (tensor<1x256x1440x4xf32>, tensor<4xi32>) -> tensor<1x256x4x1440xf32> + return %2114 : tensor<1x256x4x1440xf32> + // CHECK-DAG: %cst = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> + // CHECK: %0 = "tf.Transpose"(%arg0, %cst) : (tensor<1x4x1440x256xf32>, tensor<4xi32>) -> tensor<1x256x4x1440xf32> + // CHECK: return %0 +} + +// CHECK-LABEL @FoldTrivialReshapeIntoTranspose +func.func @FoldTrivialReshapeIntoTranspose(%arg: tensor<2x1x3x3xf32>) -> tensor<1x3x3x2xf32> { + %cst = arith.constant dense<[1, 3, 3, 2]> : tensor<4xi32> + %cst_2 = arith.constant dense<[2, 3, 0, 1]> : tensor<4xi32> + %2 = "tf.Transpose"(%arg, %cst_2) : (tensor<2x1x3x3xf32>, tensor<4xi32>) -> tensor<3x3x2x1xf32> + %3 = "tf.Reshape"(%2, 
%cst) : (tensor<3x3x2x1xf32>, tensor<4xi32>) -> tensor<1x3x3x2xf32> + return %3 : tensor<1x3x3x2xf32> + // CHECK: %cst = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32> + // CHECK: %0 = "tf.Transpose"(%arg0, %cst) : (tensor<2x1x3x3xf32>, tensor<4xi32>) -> tensor<1x3x3x2xf32> + // CHECK: return %0 : tensor<1x3x3x2xf32> +} + +// CHECK-LABEL: @MoveTransposeAcrossDepthwiseConvPerChannelQuant +func.func @MoveTransposeAcrossDepthwiseConvPerChannelQuant(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> { + %cst = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<2xf32> + %cst_1 = arith.constant dense<6.000000e+00> : tensor<2x1x3x3xf32> + %0 = "tfl.quantize"(%cst_1) {qtype = tensor<2x1x3x3x!quant.uniform:f32:0, {6.587140e-03,1.888450e-02}>>} : (tensor<2x1x3x3xf32>) -> tensor<2x1x3x3x!quant.uniform:f32:0, {6.587140e-03,1.888450e-02}>> + %1 = "tfl.dequantize"(%0) : (tensor<2x1x3x3x!quant.uniform:f32:0, {6.587140e-03,1.888450e-02}>>) -> tensor<2x1x3x3xf32> + %2 = "tf.Transpose"(%1, %cst) : (tensor<2x1x3x3xf32>, tensor<4xi32>) -> tensor<1x3x3x2xf32> + %3 = "tfl.depthwise_conv_2d"(%arg0, %2, %cst_0) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> + return %3 : tensor<1x112x112x2xf32> + // CHECK: %cst = arith.constant dense<0.000000e+00> : tensor<2xf32> + // CHECK: %cst_0 = arith.constant dense<6.000000e+00> : tensor<1x3x3x2xf32> + // CHECK: %0 = "tfl.quantize"(%cst_0) <{qtype = tensor<1x3x3x2x!quant.uniform:f32:3, {6.587140e-03,1.888450e-02}>>}> : (tensor<1x3x3x2xf32>) -> tensor<1x3x3x2x!quant.uniform:f32:3, {6.587140e-03,1.888450e-02}>> + // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x3x3x2x!quant.uniform:f32:3, {6.587140e-03,1.888450e-02}>>) -> tensor<1x3x3x2xf32> + // CHECK: %2 = "tfl.depthwise_conv_2d"(%arg0, %1, %cst) <{depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> + // CHECK: return %2 : tensor<1x112x112x2xf32> +} + } diff --git a/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir b/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir index f5a6d68c6f7ec3..a5da33ca90191b 100644 --- a/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir +++ b/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir @@ -95,7 +95,7 @@ func.func @pushTposeBcastNoChange(%arg0: tensor<2x3x4x1xf32>) -> tensor<5x2x3x4x // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<5x2x3x4xf32> // CHECK: %cst_0 = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> // CHECK: %0 = "tfl.transpose"(%arg0, %cst_0) : (tensor<2x3x4x1xf32>, tensor<4xi32>) -> tensor<1x2x3x4xf32> -// CHECK: %1 = tfl.add(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x2x3x4xf32>, tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> +// CHECK: %1 = tfl.add(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<1x2x3x4xf32>, tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> // ----- @@ -110,7 +110,7 @@ func.func @doubleTposeOneBroadcastInput(%arg0: tensor<2x3x4x1xf32>, %arg1: tenso } // CHECK: %cst = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> -// CHECK: %0 = tfl.add(%arg0, %arg1)
{fused_activation_function = "NONE"} : (tensor<2x3x4x1xf32>, tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> +// CHECK: %0 = tfl.add(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<2x3x4x1xf32>, tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> // CHECK: %1 = "tfl.transpose"(%0, %cst) : (tensor<2x3x4x5xf32>, tensor<4xi32>) -> tensor<5x2x3x4xf32> // CHECK: return %1 : tensor<5x2x3x4xf32> @@ -145,7 +145,7 @@ func.func @pushTposeBcastCstInput(%arg0: tensor<2x3x4x5xf32>) -> tensor<5x2x3x4x // CHECK: %cst = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> // CHECK: %cst_0 = arith.constant dense<1.000000e+00> : tensor<2x3x4x1xf32> -// CHECK: %0 = tfl.add(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<2x3x4x5xf32>, tensor<2x3x4x1xf32>) -> tensor<2x3x4x5xf32> +// CHECK: %0 = tfl.add(%arg0, %cst_0) <{fused_activation_function = "NONE"}> : (tensor<2x3x4x5xf32>, tensor<2x3x4x1xf32>) -> tensor<2x3x4x5xf32> // CHECK: %1 = "tfl.transpose"(%0, %cst) : (tensor<2x3x4x5xf32>, tensor<4xi32>) -> tensor<5x2x3x4xf32> // ----- @@ -161,7 +161,7 @@ func.func @pushTposeBcastScalarCstInput(%arg0: tensor<2x3x4x5xf32>) -> tensor<5x // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor // CHECK: %cst_0 = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> -// CHECK: %0 = tfl.add(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x3x4x5xf32>, tensor) -> tensor<2x3x4x5xf32> +// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x3x4x5xf32>, tensor) -> tensor<2x3x4x5xf32> // CHECK: %1 = "tfl.transpose"(%0, %cst_0) : (tensor<2x3x4x5xf32>, tensor<4xi32>) -> tensor<5x2x3x4xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range-float16.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range-float16.mlir new file mode 100644 index 00000000000000..5e0599560975fb --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range-float16.mlir @@ -0,0 +1,78 @@ +// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="enable-float16-quantization" -tfl-quantize="enable-dynamic-range-quantization=true" | FileCheck --check-prefix=CHECK %s + +// CHECK-LABEL: QuantizeUnidirectionalLstm +func.func @QuantizeUnidirectionalLstm(%arg0: tensor<1x2x3xf32>) -> (tensor<1x2x3xf32>) { + %1 = "tfl.pseudo_const"() {value = dense<[[0.1]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %2 = "tfl.pseudo_const"() {value = dense<[[0.2]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %3 = "tfl.pseudo_const"() {value = dense<[[0.3]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %4 = "tfl.pseudo_const"() {value = dense<[[0.4]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %5 = "tfl.pseudo_const"() {value = dense<[[0.5]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %6 = "tfl.pseudo_const"() {value = dense<[[0.6]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %7 = "tfl.pseudo_const"() {value = dense<[[0.7]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %8 = "tfl.pseudo_const"() {value = dense<[[0.8]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %9 = "tfl.no_value"() {value} : () -> none + %10 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %11 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %recurrent_input = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %cell_input = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %16 = "tfl.unidirectional_sequence_lstm"( + %arg0, + 
%1, %2, %3, %4, + %5, %6, %7, %8, + %9, %9, %9, + %10, %11, + %10, %10, + %9, %9, + %recurrent_input, %cell_input, + %9, %9, %9, %9) { + cell_clip = 1.000000e+01 : f32, + fused_activation_function = "TANH", + proj_clip = 0.000000e+00 : f32, + time_major = false} : ( + tensor<1x2x3xf32>, + tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, + tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, + none, none, none, + tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, + none, none, + tensor<1x3xf32>, tensor<1x3xf32>, + none, none, none, none) -> tensor<1x2x3xf32> + %17 = "quantfork.stats"(%16) {layerStats = dense<[-0.1, 0.1]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + func.return %17 : tensor<1x2x3xf32> + + // CHECK: %[[NONE:.*]] = "tfl.no_value"() <{value}> : () -> none + // CHECK: %[[DQ_1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1xf16>) -> tensor<1x1xf32> + // CHECK: %[[DQ_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1xf16>) -> tensor<1x1xf32> + // CHECK: %[[DQ_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1xf16>) -> tensor<1x1xf32> + // CHECK: %[[DQ_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1xf16>) -> tensor<1x1xf32> + // CHECK: %[[DQ_5:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1xf16>) -> tensor<1x1xf32> + // CHECK: %[[DQ_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1xf16>) -> tensor<1x1xf32> + // CHECK: %[[DQ_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1xf16>) -> tensor<1x1xf32> + // CHECK: %[[DQ_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1xf16>) -> tensor<1x1xf32> + // CHECK: %[[DQ_9:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3xf16>) -> tensor<3xf32> + // CHECK: %[[DQ_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3xf16>) -> tensor<3xf32> + // CHECK: %[[DQ_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x3xf16>) -> tensor<1x3xf32> + // CHECK: %[[DQ_12:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x3xf16>) -> tensor<1x3xf32> + // CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"( + // CHECK-SAME: %arg0, + // CHECK-SAME: %[[DQ_1]], %[[DQ_2]], %[[DQ_3]], %[[DQ_4]], + // CHECK-SAME: %[[DQ_5]], %[[DQ_6]], %[[DQ_7]], %[[DQ_8]], + // CHECK-SAME: %[[NONE]], %[[NONE]], %[[NONE]], + // CHECK-SAME: %[[DQ_9]], %[[DQ_10]], %[[DQ_9]], %[[DQ_9]], + // CHECK-SAME: %[[NONE]], %[[NONE]], + // CHECK-SAME: %[[DQ_11]], %[[DQ_12]], + // CHECK-SAME: %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]]) <{ + // CHECK-SAME: cell_clip = 1.000000e+01 : f32, + // CHECK-SAME: fused_activation_function = "TANH", + // CHECK-SAME: proj_clip = 0.000000e+00 : f32, + // CHECK-SAME: time_major = false}> : ( + // CHECK-SAME: tensor<1x2x3xf32>, + // CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, + // CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, + // CHECK-SAME: none, none, none, + // CHECK-SAME: tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, + // CHECK-SAME: none, none, + // CHECK-SAME: tensor<1x3xf32>, tensor<1x3xf32>, + // CHECK-SAME: none, none, none, none) + // CHECK-SAME: -> tensor<1x2x3xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir index ad4ff5a129f4a2..47a2947692d1eb 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir @@ -18,39 +18,39 @@ func.func @QuantizeConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x64 func.return %conv : tensor<1x112x112x64xf32> // CHECK: 
%[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32:0, { -// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) { +// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32:0, { +// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) <{ // CHECK-NOT: asymmetric_quantize_inputs = true // CHECK-SAME: dilation_h_factor = 1 : i32 // CHECK: return %[[conv:.*]] // PerTensor: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> -// PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) { +// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) <{ // PerTensor-NOT: asymmetric_quantize_inputs = true // PerTensor-SAME: dilation_h_factor = 1 : i32 // PerTensor: return %[[conv:.*]] // PerChannelWeightOnly: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32:0, { +// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32:0, { // PerChannelWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<64x3x3x3x!quant.uniform:f32:0, { -// PerChannelWeightOnly: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) { +// PerChannelWeightOnly: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) <{ // PerChannelWeightOnly-NOT: asymmetric_quantize_inputs = true // PerChannelWeightOnly-SAME: dilation_h_factor = 1 : i32 // PerChannelWeightOnly: return %[[conv:.*]] // PerTensorWeightOnly: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// PerTensorWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// PerTensorWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // PerTensorWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> -// PerTensorWeightOnly: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) { +// PerTensorWeightOnly: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) <{ // PerTensorWeightOnly-NOT: asymmetric_quantize_inputs = true // PerTensorWeightOnly-SAME: dilation_h_factor = 1 : i32 // PerTensorWeightOnly: return %[[conv:.*]] // BLOCK: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// BLOCK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// BLOCK: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // BLOCK: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> -// BLOCK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) { +// BLOCK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w]], %[[b]]) <{ // BLOCK: return %[[conv:.*]] } @@ -63,15 +63,15 @@ func.func @QuantizeDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1 func.return %dconv : tensor<1x112x112x64xf32> // CHECK: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<64xf32> -// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00} -// CHECK: %[[dconv:.*]] = 
"tfl.depthwise_conv_2d"(%arg0, %[[w]], %[[b]]) { +// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00} +// CHECK: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w]], %[[b]]) <{ // CHECK-NOT: asymmetric_quantize_inputs = true // CHECK-SAME: depth_multiplier = 4 : i32 // CHECK: return %[[dconv:.*]] // PerTensor: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<64xf32> -// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> -// PerTensor: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w]], %[[b]]) { +// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w]], %[[b]]) <{ // PerTensor-NOT: asymmetric_quantize_inputs = true // PerTensor-SAME: depth_multiplier = 4 : i32 // PerTensor: return %[[dconv:.*]] @@ -88,31 +88,31 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 func.return %fc : tensor<1x112x112x512xf32> // CHECK: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 -// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) { +// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 +// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) <{ // CHECK-NOT: fused_activation_function = "NONE", // CHECK-SAME: asymmetric_quantize_inputs = true, // CHECK: return %[[fc:.*]] // PerTensor: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>> -// PerTensor: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) { +// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) <{ // PerTensor-NOT: fused_activation_function = "NONE", // PerTensor-SAME: asymmetric_quantize_inputs = true, // PerTensor: return %[[fc:.*]] // PerChannelWeightOnly: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 +// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 // PerChannelWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 -// PerChannelWeightOnly: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) { +// PerChannelWeightOnly: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) <{ // PerChannelWeightOnly-NOT: fused_activation_function = "NONE", // PerChannelWeightOnly-SAME: asymmetric_quantize_inputs = true, // PerChannelWeightOnly: return %[[fc:.*]] // PerTensorWeightOnly: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// PerTensorWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>> +// PerTensorWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = 
tensor<512x12x!quant.uniform:f32, 1.000000e+00>> // PerTensorWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<512x12x!quant.uniform:f32, 1.000000e+00>> -// PerTensorWeightOnly: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) { +// PerTensorWeightOnly: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) <{ // PerTensorWeightOnly-NOT: fused_activation_function = "NONE", // PerTensorWeightOnly-SAME: asymmetric_quantize_inputs = true, // PerTensorWeightOnly: return %[[fc:.*]] @@ -126,13 +126,13 @@ func.func @QuantizeMatmulWithActConst(%arg0: tensor<1x3x3x512xf32>) -> tensor<1x %mm = "tfl.batch_matmul"(%arg0, %w) {adj_x = false, adj_y = false} : (tensor<1x3x3x512xf32>, tensor<512x12xf32>) -> tensor<1x3x3x12xf32> func.return %mm : tensor<1x3x3x12xf32> -// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>>, -// CHECK: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[w]]) {adj_x = false, adj_y = false +// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>>, +// CHECK: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[w]]) <{adj_x = false, adj_y = false // CHECK-SAME: , asymmetric_quantize_inputs = true // CHECK: return %[[mm:.*]] -// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>>, -// PerTensor: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[w]]) {adj_x = false, adj_y = false +// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>>, +// PerTensor: %[[mm:.*]] = "tfl.batch_matmul"(%arg0, %[[w]]) <{adj_x = false, adj_y = false // PerTensor-SAME: , asymmetric_quantize_inputs = true // PerTensor: return %[[mm:.*]] } @@ -148,33 +148,33 @@ func.func @QuantizeTransposeConvWeightOnly(%arg0: tensor<32x4x4x128xf32>, %arg1: func.return %tconv : tensor<1x32x42x128xf32> // CHECK: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<1x32x42x128xf32> -// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>> +// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>> // CHECK: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>) -> tensor<1x32x42x128xf32> -// CHECK: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w]], %arg0, %[[b]]) { +// CHECK: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w]], %arg0, %[[b]]) <{ // CHECK-NOT: asymmetric_quantize_inputs = true // CHECK-SAME: padding = "SAME" // CHECK: return %[[tconv:.*]] // PerTensor: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<1x32x42x128xf32> -// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>> // PerTensor: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32> -// PerTensor: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w]], %arg0, %[[b]]) { +// PerTensor: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w]], %arg0, %[[b]]) <{ // PerTensor-NOT: asymmetric_quantize_inputs = true // PerTensor-SAME: padding = "SAME" // PerTensor: return %[[tconv:.*]] // PerChannelWeightOnly: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<1x32x42x128xf32> -// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = 
tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>> +// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>> // PerChannelWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>) -> tensor<1x32x42x128xf32> -// PerChannelWeightOnly: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w]], %arg0, %[[b]]) { +// PerChannelWeightOnly: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w]], %arg0, %[[b]]) <{ // PerChannelWeightOnly-NOT: asymmetric_quantize_inputs = true // PerChannelWeightOnly-SAME: padding = "SAME" // PerChannelWeightOnly: return %[[tconv:.*]] // PerTensorWeightOnly: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<1x32x42x128xf32> -// PerTensorWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>> +// PerTensorWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>> // PerTensorWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32> -// PerTensorWeightOnly: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w]], %arg0, %[[b]]) { +// PerTensorWeightOnly: %[[tconv:.*]] = "tfl.transpose_conv"(%arg1, %[[dq_w]], %arg0, %[[b]]) <{ // PerTensorWeightOnly-NOT: asymmetric_quantize_inputs = true // PerTensorWeightOnly-SAME: padding = "SAME" // PerTensorWeightOnly: return %[[tconv:.*]] @@ -188,12 +188,12 @@ func.func @QuantizeGatherWeightOnly(%arg0: tensor<3xi32>) -> tensor<3x3x3x3xf32> %emb_s = "quantfork.stats"(%emb) {layerStats = dense<[0.000000e+00, 1.000000e+01]> : tensor<2xf32>} : (tensor<3x3x3x3xf32>) -> tensor<3x3x3x3xf32> func.return %emb_s : tensor<3x3x3x3xf32> -// CHECK: %[[q_w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// CHECK: %[[q_w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // CHECK: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> // CHECK: %[[emb:.*]] = "tfl.gather"(%[[dq_w]], %arg0) // CHECK: return %[[emb:.*]] -// PerTensor: %[[q_w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[q_w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // PerTensor: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> // PerTensor: %[[emb:.*]] = "tfl.gather"(%[[dq_w]], %arg0) // PerTensor: return %[[emb:.*]] @@ -209,16 +209,16 @@ func.func @QuantizeCustomOp(%arg0: tensor<1x1x1x1xf32>) -> tensor<*xf32> attribu func.return %custom : tensor<*xf32> // CHECK: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1024x1x1x1xf32> -// CHECK: %[[custom:.*]] = "tfl.custom"(%arg0, %[[w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} +// CHECK: %[[custom:.*]] = "tfl.custom"(%arg0, %[[w:.*]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> // CHECK: return %[[custom:.*]] -// CustomOpWeightOnly: %[[q_w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// CustomOpWeightOnly: %[[q_w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> // CustomOpWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w:.*]]) : (tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>>) -> 
tensor<1024x1x1x1xf32> -// CustomOpWeightOnly: %[[custom:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} +// CustomOpWeightOnly: %[[custom:.*]] = "tfl.custom"(%arg0, %[[dq_w:.*]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> // CustomOpWeightOnly: return %[[custom:.*]] -// CustomOpNotWeightOnly: %[[q_w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> -// CustomOpNotWeightOnly: %[[custom:.*]] = "tfl.custom"(%arg0, %[[q_w:.*]]) {custom_code = "CustomTestOp", custom_option = #tfl} +// CustomOpNotWeightOnly: %[[q_w:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<1024x1x1x1x!quant.uniform:f32, 1.000000e+00>> +// CustomOpNotWeightOnly: %[[custom:.*]] = "tfl.custom"(%arg0, %[[q_w:.*]]) <{custom_code = "CustomTestOp", custom_option = #tfl}> // CustomOpNotWeightOnly: return %[[custom:.*]] } @@ -234,22 +234,22 @@ func.func @NotQuantizeConv3D(%arg0: tensor<1x32x32x32x8xf32>) -> tensor<1x32x32x // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1x1x1x8x16xf32> // CHECK-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> -// CHECK: %[[conv_3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[b]]) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} +// CHECK: %[[conv_3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[b]]) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> // CHECK: return %[[conv_3d:.*]] // PerTensor: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1x1x1x8x16xf32> // PerTensor: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> -// PerTensor: %[[conv_3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[b]]) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} +// PerTensor: %[[conv_3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[b]]) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> // PerTensor: return %[[conv_3d:.*]] // PerChannelWeightOnly: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1x1x1x8x16xf32> // PerChannelWeightOnly: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> -// PerChannelWeightOnly: %[[conv_3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[b]]) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} +// PerChannelWeightOnly: %[[conv_3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[b]]) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> // PerChannelWeightOnly: return %[[conv_3d:.*]] // PerTensorWeightOnly: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<1x1x1x8x16xf32> // PerTensorWeightOnly: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> -// PerTensorWeightOnly: %[[conv_3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[b]]) 
{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} +// PerTensorWeightOnly: %[[conv_3d:.*]] = "tfl.conv_3d"(%arg0, %[[w]], %[[b]]) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> // PerTensorWeightOnly: return %[[conv_3d:.*]] } @@ -266,50 +266,50 @@ func.func @QuantizeMultiUses(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<3xi32 func.return %bmm, %emb : tensor<1x112x112x112xf32>, tensor<3x3x3x3xf32> // CHECK-DAG: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// CHECK-DAG: %[[w1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// CHECK-DAG: %[[w1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // CHECK-DAG: %[[dq_w1:.*]] = "tfl.dequantize"(%[[w1]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> -// CHECK-DAG: %[[w2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00} -// CHECK-DAG: %[[w3:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00 +// CHECK-DAG: %[[w2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00}> +// CHECK-DAG: %[[w3:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00 // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[w3]], %[[b]]) // CHECK: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w2]], %[[b]]) // CHECK: %[[emb:.*]] = "tfl.gather"(%[[dq_w1]], %arg1) -// CHECK: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) {adj_x = false, adj_y = true +// CHECK: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) <{adj_x = false, adj_y = true // CHECK-NOT: , asymmetric_quantize_inputs = true // CHECK-SAME: } // CHECK: return %[[bmm:.*]], %[[emb:.*]] // PerTensor: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// PerTensor: %[[w1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[w1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // PerTensor: %[[dq_w1:.*]] = "tfl.dequantize"(%[[w1]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> // PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[w1]], %[[b]]) // PerTensor: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w1]], %[[b]]) // PerTensor: %[[emb:.*]] = "tfl.gather"(%[[dq_w1]], %arg1) -// PerTensor: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) {adj_x = false, adj_y = true +// PerTensor: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) <{adj_x = false, adj_y = true // PerTensor-NOT: , asymmetric_quantize_inputs = true // PerTensor-SAME: } // PerTensor: return %[[bmm:.*]], %[[emb:.*]] // PerChannelWeightOnly-DAG: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// PerChannelWeightOnly-DAG: %[[w1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// PerChannelWeightOnly-DAG: %[[w1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // PerChannelWeightOnly-DAG: %[[dq_w1:.*]] = 
"tfl.dequantize"(%[[w1]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> -// PerChannelWeightOnly-DAG: %[[w2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00} +// PerChannelWeightOnly-DAG: %[[w2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00} // PerChannelWeightOnly-DAG: %[[dq_w2:.*]] = "tfl.dequantize"(%[[w2]]) : (tensor<64x3x3x3x!quant.uniform:f32:3, {1.000000e+00,1.000000e+00,1.000000e+00}>>) -> tensor<64x3x3x3xf32> -// PerChannelWeightOnly-DAG: %[[w3:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00 +// PerChannelWeightOnly-DAG: %[[w3:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00 // PerChannelWeightOnly-DAG: %[[dq_w3:.*]] = "tfl.dequantize"(%[[w3]]) : (tensor<64x3x3x3x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00 // PerChannelWeightOnly: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w3]], %[[b]]) // PerChannelWeightOnly: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq_w2]], %[[b]]) // PerChannelWeightOnly: %[[emb:.*]] = "tfl.gather"(%[[dq_w1]], %arg1) -// PerChannelWeightOnly: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) {adj_x = false, adj_y = true +// PerChannelWeightOnly: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) <{adj_x = false, adj_y = true // PerChannelWeightOnly-NOT: , asymmetric_quantize_inputs = true // PerChannelWeightOnly-SAME: } // PerChannelWeightOnly: return %[[bmm:.*]], %[[emb:.*]] // BLOCK: %[[b:.*]] = arith.constant dense<-1.23697901> : tensor<64xf32> -// BLOCK: %[[w1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> +// BLOCK: %[[w1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> // BLOCK: %[[dq_w1:.*]] = "tfl.dequantize"(%[[w1]]) : (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> tensor<64x3x3x3xf32> // BLOCK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq_w1]], %[[b]]) // BLOCK: %[[dconv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w1]], %[[b]]) // BLOCK: %[[emb:.*]] = "tfl.gather"(%[[dq_w1]], %arg1) -// BLOCK: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) {adj_x = false, adj_y = true +// BLOCK: %[[bmm:.*]] = "tfl.batch_matmul"(%[[conv]], %[[dconv]]) <{adj_x = false, adj_y = true // BLOCK: return %[[bmm:.*]], %[[emb:.*]] } diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-numeric-verify.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-numeric-verify.mlir index 7990b3aaf9e151..d043b98aa94899 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-numeric-verify.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-numeric-verify.mlir @@ -18,7 +18,7 @@ func.func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32> // DEBUG: %[[f_conv:.*]] = "tfl.conv_2d"(%[[act]], %[[wt]], %[[bias]]) // DEBUG: %[[q_conv:.*]] = "tfl.conv_2d" -// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {log_if_failed = true, tolerance = 5.000000e+00 : f32} +// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) <{log_if_failed = true, tolerance = 5.000000e+00 : f32}> // DEBUG: return %[[q_conv]] : tensor<1x112x112x32x!quant.uniform> } @@ -56,8 +56,8 @@ func.func @QuantizeSplit(%arg: tensor<4x!quant.uniform>, %cst: tens // DEBUG: %[[f_split:.*]]:2 = "tfl.split" // DEBUG: %[[q_split:.*]]:2 = "tfl.split" -// DEBUG: 
"tfl.NumericVerify"(%[[q_split]]#1, %[[f_split]]#1) {log_if_failed = true, tolerance = 5.000000e+00 : f32} -// DEBUG: "tfl.NumericVerify"(%[[q_split]]#0, %[[f_split]]#0) {log_if_failed = true, tolerance = 5.000000e+00 : f32} +// DEBUG: "tfl.NumericVerify"(%[[q_split]]#1, %[[f_split]]#1) <{log_if_failed = true, tolerance = 5.000000e+00 : f32}> +// DEBUG: "tfl.NumericVerify"(%[[q_split]]#0, %[[f_split]]#0) <{log_if_failed = true, tolerance = 5.000000e+00 : f32}> } // DEBUG-LABEL: NotQuantizePow diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir index 58dfed58a698e7..a5ac48521818ab 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir @@ -8,10 +8,10 @@ func.func @QuantizeReadVariable() -> (tensor<1x2x1x3x!quant.uniform %3 = "tfl.quantize"(%2) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> func.return %3 : tensor<1x2x1x3x!quant.uniform> -// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() <{container = "", shared_name = ""}> : () -> tensor<*x!tf_type.resource>>> // CHECK-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x1x3x!quant.uniform> // CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[rv]]) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> -// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[dq]]) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[dq]]) <{qtype = tensor<1x2x1x3x!quant.uniform>}> {volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> // CHECK-NEXT: return %[[q]] : tensor<1x2x1x3x!quant.uniform> } @@ -22,7 +22,7 @@ func.func @QuantizeAssignVariableWithDequantAndEqualType(%arg0 : tensor<1x2x1x3x "tfl.assign_variable"(%0, %1) : (tensor, tensor<1x2x1x3xf32>) -> () func.return %arg0 : tensor<1x2x1x3x!quant.uniform> -// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() <{container = "", shared_name = ""}> : () -> tensor<*x!tf_type.resource>>> // CHECK-NEXT: "tfl.assign_variable"(%[[vh]], %arg0) : (tensor<*x!tf_type.resource>>>, tensor<1x2x1x3x!quant.uniform>) -> () // CHECK-NEXT: return %arg0 : tensor<1x2x1x3x!quant.uniform> } @@ -36,11 +36,11 @@ func.func @QuantizeAssignVariableWithDequantAndNotEqualType(%arg0 : tensor<1x2x1 "tfl.assign_variable"(%1, %5) : (tensor, tensor<1x2x1x3xf32>) -> () func.return %arg0 : tensor<1x2x1x3x!quant.uniform> -// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() <{container = "", shared_name = ""}> : () -> tensor<*x!tf_type.resource>>> // CHECK-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x1x3x!quant.uniform> // CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[rv]]) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> -// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[dq]]) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> -// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x1x3x!quant.uniform>} : 
(tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[dq]]) <{qtype = tensor<1x2x1x3x!quant.uniform>}> {volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x1x3x!quant.uniform>}> : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3x!quant.uniform> // CHECK-NEXT: "tfl.assign_variable"(%[[vh]], %[[q2]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x1x3x!quant.uniform>) -> () // CHECK-NEXT: return %arg0 : tensor<1x2x1x3x!quant.uniform> } @@ -54,10 +54,10 @@ func.func @QuantizeAssignVariableWithoutDequant(%arg0 : tensor<1x2x1x3xf32>) -> "tfl.assign_variable"(%0, %3) : (tensor, tensor<1x2x1x3xf32>) -> () func.return %arg0 : tensor<1x2x1x3xf32> -// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() <{container = "", shared_name = ""}> : () -> tensor<*x!tf_type.resource>>> // CHECK-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x1x3x!quant.uniform> // CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[rv]]) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> -// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[dq]]) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[dq]]) <{qtype = tensor<1x2x1x3x!quant.uniform>}> {volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> // CHECK-NEXT: "tfl.assign_variable"(%[[vh]], %[[q]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x1x3x!quant.uniform>) -> () // CHECK-NEXT: return %arg0 : tensor<1x2x1x3xf32> } @@ -67,7 +67,7 @@ func.func @VarHandleCase(%arg0 : tensor<1x2x1x3xf32>) -> tensor<1x2x1x3xf32> { %0 = "tfl.var_handle"() : () -> tensor func.return %arg0 : tensor<1x2x1x3xf32> -// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() <{container = "", shared_name = ""}> : () -> tensor<*x!tf_type.resource>>> // CHECK-NEXT: return %arg0 : tensor<1x2x1x3xf32> } @@ -89,19 +89,19 @@ func.func @QuantizeReadAssign(%arg0: tensor<1x32x1x3xf32>) -> (tensor<1x34x1x3xf "tfl.assign_variable"(%2, %9) : (tensor, tensor<1x2x1x3xf32>) -> () func.return %6 : tensor<1x34x1x3xf32> -// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x32x1x3x!quant.uniform>, volatile} : (tensor<1x32x1x3xf32>) -> tensor<1x32x1x3x!quant.uniform> +// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x32x1x3x!quant.uniform>}> {volatile} : (tensor<1x32x1x3xf32>) -> tensor<1x32x1x3x!quant.uniform> // CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x32x1x3x!quant.uniform>) -> tensor<1x32x1x3xf32> // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1> : tensor<4xi32> // CHECK-NEXT: %[[cst_0:.*]] = arith.constant dense<[0, 0, 0, 3]> : tensor<4xi32> // CHECK-NEXT: %[[cst_1:.*]] = arith.constant dense<[0, -2, 0, 0]> : tensor<4xi32> -// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = "read_assign2/states"} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() <{container = "", shared_name = "read_assign2/states"}> : () -> tensor<*x!tf_type.resource>>> // CHECK-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x1x3x!quant.uniform> // 
CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[rv]]) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> -// CHECK-NEXT: %[[cc:.*]] = "tfl.concatenation"(%[[dq2]], %[[dq1]]) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x1x3xf32>, tensor<1x32x1x3xf32>) -> tensor<1x34x1x3xf32> -// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cc]]) {qtype = tensor<1x34x1x3x!quant.uniform>, volatile} : (tensor<1x34x1x3xf32>) -> tensor<1x34x1x3x!quant.uniform> +// CHECK-NEXT: %[[cc:.*]] = "tfl.concatenation"(%[[dq2]], %[[dq1]]) <{axis = 1 : i32, fused_activation_function = "NONE"}> : (tensor<1x2x1x3xf32>, tensor<1x32x1x3xf32>) -> tensor<1x34x1x3xf32> +// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cc]]) <{qtype = tensor<1x34x1x3x!quant.uniform>}> {volatile} : (tensor<1x34x1x3xf32>) -> tensor<1x34x1x3x!quant.uniform> // CHECK-NEXT: %[[dq3:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x34x1x3x!quant.uniform>) -> tensor<1x34x1x3xf32> -// CHECK-NEXT: %[[ss:.*]] = "tfl.strided_slice"(%[[dq3]], %[[cst_1]], %[[cst_0]], %[[cst]]) {begin_mask = 13 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<1x34x1x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> -// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%[[ss]]) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[ss:.*]] = "tfl.strided_slice"(%[[dq3]], %[[cst_1]], %[[cst_0]], %[[cst]]) <{begin_mask = 13 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<1x34x1x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> +// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%[[ss]]) <{qtype = tensor<1x2x1x3x!quant.uniform>}> {volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> // CHECK-NEXT: "tfl.assign_variable"(%[[vh]], %[[q3]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x1x3x!quant.uniform>) -> () // CHECK-NEXT: return %[[dq3]] : tensor<1x34x1x3xf32> } @@ -133,11 +133,11 @@ func.func @QuantizeConvVariable(%arg0: tensor<1x3x1x1xf32>) -> (tensor<1x3x1x1xf "tfl.assign_variable"(%6, %16) : (tensor, tensor<1x3x1x1xf32>) -> () func.return %10 : tensor<1x3x1x1xf32> -// WHOLE-PASSES: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = "conv_variable/state"} : () -> tensor<*x!tf_type.resource>>> +// WHOLE-PASSES: %[[vh:.*]] = "tfl.var_handle"() <{container = "", shared_name = "conv_variable/state"}> : () -> tensor<*x!tf_type.resource>>> // WHOLE-PASSES-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x3x1x1x!quant.uniform> -// WHOLE-PASSES-DAG: %[[cv:.*]] = "tfl.conv_2d"(%arg0, {{.*}}) {{{.*}}} : (tensor<1x3x1x1x!quant.uniform>, tensor<1x3x1x1x!quant.uniform:f32:0, {{.*}}>>, tensor<1x!quant.uniform>) -> tensor<1x3x1x1x!quant.uniform> -// WHOLE-PASSES-NEXT: %[[cc:.*]] = "tfl.concatenation"(%[[rv]], %[[cv]]) {{{.*}}} : (tensor<1x3x1x1x!quant.uniform>, tensor<1x3x1x1x!quant.uniform>) -> tensor<1x6x1x1x!quant.uniform> -// WHOLE-PASSES-NEXT: %[[ss:.*]] = "tfl.strided_slice"(%[[cc]], {{.*}}) {{{.*}}} : (tensor<1x6x1x1x!quant.uniform>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x3x1x1x!quant.uniform> +// WHOLE-PASSES-DAG: %[[cv:.*]] = "tfl.conv_2d"(%arg0, {{.*}}) <{{{.*}}}> : (tensor<1x3x1x1x!quant.uniform>, tensor<1x3x1x1x!quant.uniform:f32:0, {{.*}}>>, tensor<1x!quant.uniform>) -> tensor<1x3x1x1x!quant.uniform> +// 
WHOLE-PASSES-NEXT: %[[cc:.*]] = "tfl.concatenation"(%[[rv]], %[[cv]]) <{{{.*}}}> : (tensor<1x3x1x1x!quant.uniform>, tensor<1x3x1x1x!quant.uniform>) -> tensor<1x6x1x1x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[ss:.*]] = "tfl.strided_slice"(%[[cc]], {{.*}}) <{{{.*}}}> : (tensor<1x6x1x1x!quant.uniform>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x3x1x1x!quant.uniform> // WHOLE-PASSES-NEXT: "tfl.assign_variable"(%[[vh]], %[[ss]]) : (tensor<*x!tf_type.resource>>>, tensor<1x3x1x1x!quant.uniform>) -> () // WHOLE-PASSES-NEXT: return %[[cv]] : tensor<1x3x1x1x!quant.uniform> } @@ -171,19 +171,19 @@ func.func @QuantizeTwoVariable(%arg0: tensor<1x2x3xf32>) -> (tensor<1x2x3xf32>) func.return %0 : tensor<1x2x3xf32> -// WHOLE-PASSES: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x3x!quant.uniform>} : (tensor<1x2x3x!quant.uniform>) -> tensor<1x2x3x!quant.uniform> -// WHOLE-PASSES-DAG: %[[vh1:.*]] = "tfl.var_handle"() {container = "", shared_name = "read_assign/states0"} : () -> tensor<*x!tf_type.resource>>> -// WHOLE-PASSES-DAG: %[[vh2:.*]] = "tfl.var_handle"() {container = "", shared_name = "read_assign/states1"} : () -> tensor<*x!tf_type.resource>>> +// WHOLE-PASSES: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x3x!quant.uniform>}> : (tensor<1x2x3x!quant.uniform>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-DAG: %[[vh1:.*]] = "tfl.var_handle"() <{container = "", shared_name = "read_assign/states0"}> : () -> tensor<*x!tf_type.resource>>> +// WHOLE-PASSES-DAG: %[[vh2:.*]] = "tfl.var_handle"() <{container = "", shared_name = "read_assign/states1"}> : () -> tensor<*x!tf_type.resource>>> // WHOLE-PASSES-DAG: %[[rv1:.*]] = "tfl.read_variable"({{.*}}) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x3x!quant.uniform> // WHOLE-PASSES-NEXT: %[[cc1:.*]] = "tfl.concatenation"(%[[rv1]], {{.*}}) {{.*}} : (tensor<1x2x3x!quant.uniform>, tensor<1x2x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> -// WHOLE-PASSES-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cc1]]) {qtype = tensor<1x4x3x!quant.uniform>} : (tensor<1x4x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> -// WHOLE-PASSES-NEXT: %[[ss1:.*]] = "tfl.strided_slice"(%[[q2]], {{.*}}) {{{.*}}} : (tensor<1x4x3x!quant.uniform>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cc1]]) <{qtype = tensor<1x4x3x!quant.uniform>}> : (tensor<1x4x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[ss1:.*]] = "tfl.strided_slice"(%[[q2]], {{.*}}) <{{{.*}}}> : (tensor<1x4x3x!quant.uniform>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3x!quant.uniform> // WHOLE-PASSES-NEXT: "tfl.assign_variable"(%[[vh1]], %[[ss1]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x3x!quant.uniform>) -> () // WHOLE-PASSES-DAG: %[[rv2:.*]] = "tfl.read_variable"({{.*}}) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x3x!quant.uniform> // WHOLE-PASSES-NEXT: %[[cc2:.*]] = "tfl.concatenation"(%[[rv2]], {{.*}}) {{.*}} : (tensor<1x2x3x!quant.uniform>, tensor<1x2x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> -// WHOLE-PASSES-NEXT: %[[ss2:.*]] = "tfl.strided_slice"(%[[cc2]], {{.*}}) {{{.*}}} : (tensor<1x4x3x!quant.uniform>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[ss2:.*]] = "tfl.strided_slice"(%[[cc2]], {{.*}}) <{{{.*}}}> : (tensor<1x4x3x!quant.uniform>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3x!quant.uniform> // WHOLE-PASSES-NEXT: "tfl.assign_variable"(%[[vh2]], %[[ss2]]) : 
(tensor<*x!tf_type.resource>>>, tensor<1x2x3x!quant.uniform>) -> () // WHOLE-PASSES-NEXT: return %arg0 : tensor<1x2x3x!quant.uniform> diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index 69771e0100f496..f99d3cb409f0fc 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -8,7 +8,7 @@ func.func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> func.return %1 : tensor<2x2x!quant.uniform> -// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<0> : tensor<2x2xi8>} +// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x2x!quant.uniform>, value = dense<0> : tensor<2x2xi8>}> // CHECK: return %[[cst]] } @@ -18,7 +18,7 @@ func.func @QuantizeFloatConst4Bits() -> tensor<2x4x!quant.uniform>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform> func.return %1 : tensor<2x4x!quant.uniform> -// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x4x!quant.uniform>, value = dense<{{\[\[}}-4, -3, -2, -1{{\]}}, [0, 1, 2, 3{{\]\]}}> : tensor<2x4xi4>} +// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x4x!quant.uniform>, value = dense<{{\[\[}}-4, -3, -2, -1{{\]}}, [0, 1, 2, 3{{\]\]}}> : tensor<2x4xi4>}> // CHECK: return %[[cst]] } @@ -28,7 +28,7 @@ func.func @QuantizeDenseFloatConst() -> tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> func.return %1 : tensor<2x2x!quant.uniform> -// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>} +// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x2x!quant.uniform>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>}> // CHECK: return %[[cst]] } @@ -38,7 +38,7 @@ func.func @QuantizeSplatFloatConst() -> tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> func.return %1 : tensor<2x2x!quant.uniform> -// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<-1> : tensor<2x2xi8>} +// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x2x!quant.uniform>, value = dense<-1> : tensor<2x2xi8>}> // CHECK: return %[[cst]] } @@ -60,7 +60,7 @@ func.func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> func.return %1 : tensor<2x2x!quant.uniform> -// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<-1> : tensor<2x2xi8>} +// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x2x!quant.uniform>, value = dense<-1> : tensor<2x2xi8>}> // CHECK: return %[[cst]] : tensor<2x2x!quant.uniform> } @@ -76,8 +76,8 @@ func.func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> func.return %6 : tensor<1x112x112x32x!quant.uniform> -// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-1583> : tensor<32xi32>} -// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 1.000000e-01>>, value = dense<1> : tensor<32x3x3x3xi8>} +// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<-1583> : tensor<32xi32>}> +// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 1.000000e-01>>, value = dense<1> : 
tensor<32x3x3x3xi8>}> // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[cst1]], %[[cst0]]) // CHECK: return %[[conv]] : tensor<1x112x112x32x!quant.uniform> } @@ -94,8 +94,8 @@ func.func @QuantizeConv2D4Bit(tensor<1x224x224x3x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> func.return %6 : tensor<1x112x112x32x!quant.uniform> -// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-1583> : tensor<32xi32>} -// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 1.000000e-01>>, value = dense<1> : tensor<32x3x3x3xi4>} +// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<-1583> : tensor<32xi32>}> +// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 1.000000e-01>>, value = dense<1> : tensor<32x3x3x3xi4>}> // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[cst1]], %[[cst0]]) // CHECK: return %[[conv]] : tensor<1x112x112x32x!quant.uniform> } @@ -111,9 +111,9 @@ func.func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> func.return %6 : tensor<1x112x112x32x!quant.uniform> -// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-7254> : tensor<32xi32>} -// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} -// CHECK: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[cst1]], %[[cst0]]) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} +// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<-7254> : tensor<32xi32>}> +// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}> +// CHECK: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[cst1]], %[[cst0]]) <{depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}> // CHECK: return %[[conv]] } @@ -128,9 +128,9 @@ func.func @QuantizeDepthwiseConv2D4Bit(tensor<1x224x224x3x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> func.return %6 : tensor<1x112x112x32x!quant.uniform> -// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-400> : tensor<32xi32>} -// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.39599830800000002:8>>, value = dense<-7> : tensor<32x3x3x3xi4>} -// CHECK: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[cst1]], %[[cst0]]) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} +// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<-400> : tensor<32xi32>}> +// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.39599830800000002:8>>, value = dense<-7> : tensor<32x3x3x3xi4>}> +// CHECK: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[cst1]], %[[cst0]]) <{depth_multiplier = 4 : i32, 
dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}> // CHECK: return %[[conv]] } @@ -145,14 +145,14 @@ func.func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> func.return %6 : tensor<1x112x112x32x!quant.uniform> -// CHECK: %[[cst_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-7254> : tensor<32xi32>} -// CHECK: %[[cst_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x12x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x12xi8>} -// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[cst_1]], %[[cst_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: %[[cst_0:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<-7254> : tensor<32xi32>}> +// CHECK: %[[cst_1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x12x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x12xi8>}> +// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[cst_1]], %[[cst_0]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: return %[[fc]] // BLOCK: %[[cst:.*]] = "tfl.pseudo_const"(){{.*}}dense<-1.23697901> // BLOCK: %[[dq1:.*]] = "tfl.dequantize"(%arg0) -// BLOCK: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x12x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x12xi8>} +// BLOCK: %[[cst2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x12x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x12xi8>}> // BLOCK: %[[dq2:.*]] = "tfl.dequantize"(%[[cst2]]) // BLOCK: %[[fc:.*]] = "tfl.fully_connected"(%[[dq1]], %[[dq2]], %[[cst]]) // BLOCK: %[[q:.*]] = "tfl.quantize"(%[[fc]]) @@ -170,14 +170,14 @@ func.func @QuantizeFullyConnected4Bit(tensor<1x224x224x3x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> func.return %6 : tensor<1x112x112x32x!quant.uniform> -// CHECK: %[[cst_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-400> : tensor<32xi32>} -// CHECK: %[[cst_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x12x!quant.uniform:f32, 0.39599830800000002:8>>, value = dense<-7> : tensor<32x12xi4>} -// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[cst_1]], %[[cst_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: %[[cst_0:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x!quant.uniform>, value = dense<-400> : tensor<32xi32>}> +// CHECK: %[[cst_1:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x12x!quant.uniform:f32, 0.39599830800000002:8>>, value = dense<-7> : tensor<32x12xi4>}> +// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[cst_1]], %[[cst_0]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: return %[[fc]] // BLOCK: %[[cst:.*]] = "tfl.pseudo_const"(){{.*}}dense<-1.23697901> // BLOCK: %[[dq1:.*]] = "tfl.dequantize"(%arg0) -// BLOCK: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x12x!quant.uniform:f32, 0.39599830800000002:8>>, value = dense<-7> : tensor<32x12xi4>} +// BLOCK: %[[cst2:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<32x12x!quant.uniform:f32, 0.39599830800000002:8>>, value = dense<-7> : tensor<32x12xi4>}> // BLOCK: %[[dq2:.*]] = "tfl.dequantize"(%[[cst2]]) // BLOCK: 
%[[fc:.*]] = "tfl.fully_connected"(%[[dq1]], %[[dq2]], %[[cst]]) // BLOCK: %[[q:.*]] = "tfl.quantize"(%[[fc]]) @@ -254,13 +254,13 @@ func.func @QuantizeAdd(tensor<1x56x56x24x!quant.uniform>} : (tensor<1x56x56x24xf32>) -> tensor<1x56x56x24x!quant.uniform> func.return %3 : tensor<1x56x56x24x!quant.uniform> -// CHECK: %[[add:.*]] = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x56x56x24x!quant.uniform>, tensor<1x56x56x24x!quant.uniform>) +// CHECK: %[[add:.*]] = tfl.add(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x56x56x24x!quant.uniform>, tensor<1x56x56x24x!quant.uniform>) // CHECK: return %[[add]] : tensor<1x56x56x24x!quant.uniform> // BLOCK: %[[dq0:.*]] = "tfl.dequantize"(%arg0) : (tensor<1x56x56x24x!quant.uniform>) // BLOCK: %[[dq1:.*]] = "tfl.dequantize"(%arg1) : (tensor<1x56x56x24x!quant.uniform>) // BLOCK: %[[add:.*]] = tfl.add %[[dq0]], %[[dq1]] {fused_activation_function = "NONE"} : tensor<1x56x56x24xf32> -// BLOCK: %[[q:.*]] = "tfl.quantize"(%[[add]]) {qtype = tensor<1x56x56x24x!quant.uniform>} +// BLOCK: %[[q:.*]] = "tfl.quantize"(%[[add]]) <{qtype = tensor<1x56x56x24x!quant.uniform>}> // BLOCK: return %[[q]] : tensor<1x56x56x24x!quant.uniform> } @@ -271,9 +271,9 @@ func.func @QuantizeConcat(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> func.return %1 : tensor<2x2x!quant.uniform> -// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile} -// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>, volatile} -// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q1]], %[[q0]]) {axis = 0 : i32, fused_activation_function = "NONE"} +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} +// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q1]], %[[q0]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> // CHECK: return %[[cc]] : tensor<2x2x!quant.uniform> } @@ -285,9 +285,9 @@ func.func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform>, %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> func.return %3 : tensor<2x2x!quant.uniform> -// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile} -// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} -// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q0]], %[[q1]]) {axis = 0 : i32, fused_activation_function = "NONE"} +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> +// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q0]], %[[q1]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> // CHECK: return %[[cc]] : tensor<2x2x!quant.uniform> } @@ -298,7 +298,7 @@ func.func @QuantizeMaxPool2D(tensor<1x6x6x16x!quant.uniform) -> tensor<1x1x1x16xf32> func.return %1 : tensor<1x1x1x16xf32> -// CHECK: %[[mp:.*]] = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x1x1x16x!quant.uniform> +// CHECK: %[[mp:.*]] = "tfl.max_pool_2d"(%arg0) <{filter_height = 1 : i32, filter_width = 1 : i32, 
fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x1x1x16x!quant.uniform> // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[mp]]) : (tensor<1x1x1x16x!quant.uniform>) -> tensor<1x1x1x16xf32> // CHECK: return %[[dq]] : tensor<1x1x1x16xf32> } @@ -311,7 +311,7 @@ func.func @QuantizeSplit(%arg: tensor<4x!quant.uniform>, %cst: tens %3 = "tfl.quantize"(%1#1) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> func.return %2, %3 : tensor<2x!quant.uniform>, tensor<2x!quant.uniform> -// CHECK: %[[sp:.*]]:2 = "tfl.split"(%arg1, %arg0) {num_splits = 2 : i32} : (tensor, tensor<4x!quant.uniform>) +// CHECK: %[[sp:.*]]:2 = "tfl.split"(%arg1, %arg0) <{num_splits = 2 : i32}> : (tensor, tensor<4x!quant.uniform>) // CHECK: return %[[sp]]#0, %[[sp]]#1 } @@ -324,7 +324,7 @@ func.func @QuantizeSplitUnusedResults(%arg: tensor<4x!quant.uniform %3 = "tfl.quantize"(%1#1) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> func.return %2, %3 : tensor<2x!quant.uniform>, tensor<2x!quant.uniform> -// CHECK: %[[sp:.*]]:4 = "tfl.split"(%arg1, %arg0) {num_splits = 4 : i32} : (tensor, tensor<4x!quant.uniform>) +// CHECK: %[[sp:.*]]:4 = "tfl.split"(%arg1, %arg0) <{num_splits = 4 : i32}> : (tensor, tensor<4x!quant.uniform>) // CHECK: return %[[sp]]#0, %[[sp]]#1 } @@ -440,7 +440,7 @@ func.func @CheckLegacyQuantizeAdd() -> tensor<1x2x!quant.uniform>, volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> func.return %0 : tensor<1x2x!quant.uniform> -// LEGACY: "tfl.pseudo_qconst"() {qtype = tensor<1x2x!quant.uniform>, value = dense<{{\[\[}}-1, 127]]> : tensor<1x2xi8>} +// LEGACY: "tfl.pseudo_qconst"() <{qtype = tensor<1x2x!quant.uniform>, value = dense<{{\[\[}}-1, 127]]> : tensor<1x2xi8>}> } func.func private @testIfThen(tensor<*xf32>) -> tensor<*xf32> @@ -467,8 +467,8 @@ func.func @NotQuantizeReadVariable() -> tensor<1x2x3x!quant.uniform:f3 %1 = "tfl.read_variable"(%0) : (tensor>>) -> tensor<1x2x3xf32> %2 = "tfl.quantize"(%1) {qtype = tensor<1x2x3x!quant.uniform:f32, 0.047244094488188976:128>>} : (tensor<1x2x3xf32>) -> tensor<1x2x3x!quant.uniform:f32, 0.047244094488188976:128>> func.return %2 : tensor<1x2x3x!quant.uniform:f32, 0.047244094488188976:128>> - // CHECK: %[[handle:.*]] = "tfl.var_handle"() {container = "", shared_name = "states"} : () -> tensor>> + // CHECK: %[[handle:.*]] = "tfl.var_handle"() <{container = "", shared_name = "states"}> : () -> tensor>> // CHECK-NEXT: %[[read:.*]] = "tfl.read_variable"(%[[handle]]) : (tensor>>) -> tensor<1x2x3xf32> - // CHECK-NEXT: %[[quantize:.*]] = "tfl.quantize"(%[[read]]) {qtype = tensor<1x2x3x!quant.uniform:f32, 0.047244094488188976:128>>} : (tensor<1x2x3xf32>) -> tensor<1x2x3x!quant.uniform:f32, 0.047244094488188976:128>> + // CHECK-NEXT: %[[quantize:.*]] = "tfl.quantize"(%[[read]]) <{qtype = tensor<1x2x3x!quant.uniform:f32, 0.047244094488188976:128>>}> : (tensor<1x2x3xf32>) -> tensor<1x2x3x!quant.uniform:f32, 0.047244094488188976:128>> // CHECK-NEXT: return %[[quantize]] } diff --git a/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir index 5baa981122985c..01a4a72749bf45 100644 --- a/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir +++ b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir @@ -3,7 +3,7 @@ module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: testConv2dShapeValidPadding func.func @testConv2dShapeValidPadding(%arg0: 
tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> { - // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x108x76x128xf32> + // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) <{dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x108x76x128xf32> %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32> func.return %0 : tensor<1x?x?x128xf32> } @@ -14,7 +14,7 @@ func.func @testConv2dShapeValidPadding(%arg0: tensor<1x112x80x128xf32>, %arg1: t module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: testConv2dShapeInferenceSamePadding func.func @testConv2dShapeInferenceSamePadding(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> { - // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x112x80x128xf32> + // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x112x80x128xf32> %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32> func.return %0 : tensor<1x?x?x128xf32> } @@ -25,7 +25,7 @@ func.func @testConv2dShapeInferenceSamePadding(%arg0: tensor<1x112x80x128xf32>, module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: testConv2dShapeInferenceDilation func.func @testConv2dShapeInferenceDilation(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> { - // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x112x80x128xf32> + // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) <{dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x112x80x128xf32> %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, 
tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32> func.return %0 : tensor<1x?x?x128xf32> } @@ -36,7 +36,7 @@ func.func @testConv2dShapeInferenceDilation(%arg0: tensor<1x112x80x128xf32>, %ar module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: testConv2dShapeInferenceStrides func.func @testConv2dShapeInferenceStrides(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> { - // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x56x40x128xf32> + // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x56x40x128xf32> %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32> func.return %0 : tensor<1x?x?x128xf32> } @@ -47,7 +47,7 @@ func.func @testConv2dShapeInferenceStrides(%arg0: tensor<1x112x80x128xf32>, %arg module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: testConv2dShapeInferenceUnranked func.func @testConv2dShapeInferenceUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -58,7 +58,7 @@ func.func @testConv2dShapeInferenceUnranked(%arg0: tensor<*xf32>, %arg1: tensor< module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: testConv2dShapeInferenceDynamic func.func @testConv2dShapeInferenceDynamic(%arg0: tensor<1x?x?x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> { - // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x?x?x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32> + // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) <{dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x?x?x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32> %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) 
{dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x?x?x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32> func.return %0 : tensor<1x?x?x128xf32> } @@ -80,7 +80,7 @@ func.func @testConv2dShapeInvalidRanks(%arg0: tensor<1x112x80xf32>, %arg1: tenso module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: testUnidirectionalSequenceLstmShapeInference func.func @testUnidirectionalSequenceLstmShapeInference(%arg0: tensor<600 x 10 x 20 x f32>, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor<40 x f32>, %arg16: tensor, %arg17: tensor, %arg18: tensor<600 x 40 x f32>, %arg19: tensor<600 x 40 x f32>, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600x10x20xf32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<40xf32>, tensor, tensor, tensor<600x40xf32>, tensor<600x40xf32>, tensor, tensor, tensor, tensor) -> tensor<600x10x40xf32 + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) <{fused_activation_function = "NONE", time_major = false}> : (tensor<600x10x20xf32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<40xf32>, tensor, tensor, tensor<600x40xf32>, tensor<600x40xf32>, tensor, tensor, tensor, tensor) -> tensor<600x10x40xf32 %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600 x 10 x 20 x f32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<40xf32>, tensor, tensor, tensor<600x40xf32>, tensor<600x40xf32>, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -91,7 +91,7 @@ func.func @testUnidirectionalSequenceLstmShapeInference(%arg0: tensor<600 x 10 x module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: testUnidirectionalSequenceLstmShapeInference func.func @testUnidirectionalSequenceLstmShapeInference(%arg0: tensor<600 x ? 
x 20 x f32>, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor<40 x f32>, %arg16: tensor, %arg17: tensor, %arg18: tensor<600 x 40 x f32>, %arg19: tensor<600 x 40 x f32>, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600x?x20xf32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<40xf32>, tensor, tensor, tensor<600x40xf32>, tensor<600x40xf32>, tensor, tensor, tensor, tensor) -> tensor<600x?x40xf32 + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) <{fused_activation_function = "NONE", time_major = false}> : (tensor<600x?x20xf32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<40xf32>, tensor, tensor, tensor<600x40xf32>, tensor<600x40xf32>, tensor, tensor, tensor, tensor) -> tensor<600x?x40xf32 %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600 x ? 
x 20 x f32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<40xf32>, tensor, tensor, tensor<600x40xf32>, tensor<600x40xf32>, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } diff --git a/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir b/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir index 134b7a0dccf7f6..ffe1ee7264e8d1 100644 --- a/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir +++ b/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir @@ -2,9 +2,9 @@ func.func @testSingleLstm(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> { // CHECK-LABEL: testSingleLstm - // CHECK-DAG: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> - // CHECK-DAG: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> - // CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg2, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> + // CHECK-DAG: %[[CST_0:.*]] = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<4x4xf32>}> : () -> tensor<4x4xf32> + // CHECK-DAG: %[[CST_1:.*]] = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<4x4xf32>}> : () -> tensor<4x4xf32> + // CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg2, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) <{fused_activation_function = "NONE", time_major = true}> : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") %1 = "tfl.unidirectional_sequence_lstm"(%arg2, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> @@ -13,15 +13,29 @@ func.func @testSingleLstm(%arg0: tensor<4x4xf32>, 
%arg1: tensor<4xf32>, %arg2: t func.func @testMultipleLstms(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> { // CHECK-LABEL: testMultipleLstms - // CHECK-DAG: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> - // CHECK-DAG: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> - // CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg2, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> - // CHECK-DAG: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> - // CHECK-DAG: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> - // CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> + // CHECK-DAG: %[[CST_0:.*]] = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<4x4xf32>}> : () -> tensor<4x4xf32> + // CHECK-DAG: %[[CST_1:.*]] = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<4x4xf32>}> : () -> tensor<4x4xf32> + // CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg2, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) <{fused_activation_function = "NONE", time_major = true}> : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> + // CHECK-DAG: %[[CST_2:.*]] = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<4x4xf32>}> : () -> tensor<4x4xf32> + // CHECK-DAG: %[[CST_3:.*]] = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<4x4xf32>}> : () -> tensor<4x4xf32> + // CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, 
%arg1, %arg1, %arg0, %arg1, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) <{fused_activation_function = "NONE", time_major = true}> : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") %1 = "tfl.unidirectional_sequence_lstm"(%arg2, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> %2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> func.return %2 : tensor<4x4x4xf32> } + +func.func @testSingleLstmFloat16(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> { + // CHECK-LABEL: testSingleLstmFloat16 + // CHECK-DAG: %[[CST_0:.*]] = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<4x4xf16>}> : () -> tensor<4x4xf16> + // CHECK-DAG: %[[CST_1:.*]] = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<4x4xf16>}> : () -> tensor<4x4xf16> + // CHECK-DAG: %[[DQ_0:.*]] = "tfl.dequantize"(%[[CST_0]]) : (tensor<4x4xf16>) -> tensor<4x4xf32> + // CHECK-DAG: %[[DQ_1:.*]] = "tfl.dequantize"(%[[CST_1]]) : (tensor<4x4xf16>) -> tensor<4x4xf32> + // CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg2, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[DQ_0]], %[[DQ_1]], %arg0, %arg0, %arg0, %arg0) <{fused_activation_function = "NONE", time_major = true}> : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> + + %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf16>} : () -> tensor<4x4xf16> loc("Const") + %1 = "tfl.dequantize"(%0) : (tensor<4x4xf16>) ->
tensor<4x4xf32> + %2 = "tfl.unidirectional_sequence_lstm"(%arg2, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %1, %1, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> + func.return %2 : tensor<4x4x4xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir index 71f717d4c7862b..8ef595c2dcf43e 100644 --- a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir +++ b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir @@ -204,7 +204,7 @@ func.func @whileSinkConstant(%arg0: tensor<1x256xf32>) -> tensor<1x256xf32> attr "tfl.yield"(%3) : (tensor) -> () }, { ^bb0(%arg1: tensor, %arg2: tensor<1x256xf32>): - // CHECK: %[[QCONST:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<256x256x!quant.uniform>, value = dense<1> : tensor<256x256xi8>} : () -> tensor<256x256x!quant.uniform> + // CHECK: %[[QCONST:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<256x256x!quant.uniform>, value = dense<1> : tensor<256x256xi8>}> : () -> tensor<256x256x!quant.uniform> // CHECK: %[[CONST:.*]] = arith.constant dense<1> : tensor<256x256xi8> %4 = "tfl.batch_matmul"(%arg2, %cst_0) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x256xi8>) -> tensor<1x256xf32> // CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%arg1, %[[CONST]]) diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index f4aa97069655e8..ccd6b8e559eac8 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" @@ -150,12 +151,14 @@ void AddPreQuantizationStableHloToTfPasses( // to be consistent with other entrypoints. pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + pass_manager.addNestedPass( + mlir::odml::CreateOutlineCompositesPass()); // Decompose CHLO into StableHLO ops // TODO(b/331843141): There are some CHLO's like TopK which we could instead // lower to TFL ops. mlir::stablehlo::experimental::createChloLegalizeToStablehloPipeline( pass_manager); - + pass_manager.addPass(mlir::odml::CreateTransposeCommuteOpsPass()); // The following two passes find specific uniform quantization patterns in // StableHLO and converts them to TFLite ops that accept or produce uniform // quantized types. 
They only target a specific set of models that contain diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index c99780603abbb4..8d124af7cb246a 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -61,7 +62,6 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/public/session.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/schema/schema_generated.h" using mlir::MLIRContext; using mlir::ModuleOp; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index ac4de7f82b23d0..dd8b345862e3c8 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" #include "tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" @@ -80,7 +81,6 @@ limitations under the License. #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/experimental/remat/metadata_util.h" #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" -#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 463b09005544aa..f77912938d8709 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -45,7 +45,7 @@ namespace tensorflow { // file; otherwise, load from a GraphDef. // Setting prune_unused_nodes to true, would prune unreachable nodes if // output_arrays is specified. -tsl::StatusOr<OwningOpRef<mlir::ModuleOp>> LoadFromGraphdefOrMlirSource( +absl::StatusOr<OwningOpRef<mlir::ModuleOp>> LoadFromGraphdefOrMlirSource( const std::string& input_filename, bool input_mlir, bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs, const GraphImportConfig& specs, absl::string_view debug_info_file, @@ -56,7 +56,7 @@ tsl::StatusOr<OwningOpRef<mlir::ModuleOp>> LoadFromGraphdefOrMlirSource( // Load Saved model (either v1 or v2) into MLIR. // 'saved_model_bundle' will be initialized if V1 model was loaded.
-tsl::StatusOr<OwningOpRef<mlir::ModuleOp>> ImportSavedModel( +absl::StatusOr<OwningOpRef<mlir::ModuleOp>> ImportSavedModel( const std::string& input_filename, int saved_model_version, const std::unordered_set<std::string>& tags, absl::Span<const std::string> extra_tf_opdefs, diff --git a/tensorflow/compiler/mlir/lite/transforms/analyze_variables.cc b/tensorflow/compiler/mlir/lite/transforms/analyze_variables.cc index 6fd0278bf909e4..39afd416ab1aa2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/analyze_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/analyze_variables.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -92,7 +93,8 @@ void AnalyzeVariablesPass::runOnOperation() { // Note: this might disable native variables in more than needed cases. // TODO(b/189370197): Enhance variable analysis. for (auto operand : op->getOperands()) { - if (getElementTypeOrSelf(operand.getType()).isa<TF::ResourceType>()) { + if (mlir::isa<TF::ResourceType>( + getElementTypeOrSelf(operand.getType()))) { legalize_to_tfl = false; return WalkResult::interrupt(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc b/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc index 5329274271c55c..3fcd82ef033938 100644 --- a/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc +++ b/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -62,20 +63,20 @@ class DequantizeConverter : public OpRewritePattern<SrcOp> { bool allTypesFp = true; bool allTypesQuantizedOrInt = true; for (auto operand : op->getOperands()) { - ShapedType type = operand.getType().template dyn_cast<ShapedType>(); + ShapedType type = mlir::dyn_cast<ShapedType>(operand.getType()); if (!type) continue; - allTypesFp &= !type.getElementType().isa<FloatType>(); + allTypesFp &= !mlir::isa<FloatType>(type.getElementType()); allTypesQuantizedOrInt &= - (type.getElementType().isa<quant::QuantizedType>() || - type.getElementType().isa<IntegerType>()); + (mlir::isa<quant::QuantizedType>(type.getElementType()) || + mlir::isa<IntegerType>(type.getElementType())); } for (auto result : op->getResults()) { - ShapedType type = result.getType().template cast<ShapedType>(); - allTypesFp &= !type.getElementType().isa<FloatType>(); + ShapedType type = mlir::cast<ShapedType>(result.getType()); + allTypesFp &= !mlir::isa<FloatType>(type.getElementType()); allTypesQuantizedOrInt &= - (type.getElementType().isa<quant::QuantizedType>() || - type.getElementType().isa<IntegerType>()); + (mlir::isa<quant::QuantizedType>(type.getElementType()) || + mlir::isa<IntegerType>(type.getElementType())); } // If all quantized or floating point then types are consistent.
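A through-line in the transform-file changes above and below is LLVM's cast API migration: the deprecated member-style casts on mlir::Type and mlir::Attribute (x.cast<T>(), x.dyn_cast<T>(), x.isa<T>()) are replaced by the free functions mlir::cast, mlir::dyn_cast, and mlir::isa, which "mlir/Support/LLVM.h" re-exports into the mlir namespace; that is why this include is added across so many files here. A minimal before/after sketch, assuming only standard MLIR headers (the helper name IsRankedFloatTensor is hypothetical, for illustration only):

    #include "mlir/IR/BuiltinTypes.h"  // mlir::RankedTensorType, mlir::FloatType
    #include "mlir/IR/Value.h"         // mlir::Value
    #include "mlir/Support/LLVM.h"     // re-exports llvm::isa/cast/dyn_cast into mlir::

    bool IsRankedFloatTensor(mlir::Value v) {
      // Before: auto ty = v.getType().dyn_cast<mlir::RankedTensorType>();
      auto ty = mlir::dyn_cast<mlir::RankedTensorType>(v.getType());
      // Before: ty.getElementType().isa<mlir::FloatType>()
      return ty && mlir::isa<mlir::FloatType>(ty.getElementType());
    }

The free-function form works uniformly on types, attributes, and operations, which is what lets these hunks drop the occasional `.template` disambiguation the member form required in dependent contexts.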
diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 2f015e61d58fe6..94ed4b1e0340a5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -152,7 +153,7 @@ void DefaultQuantParamsPass::AddToWorkListIfUnquantized( Value value, std::vector *values) { // If the result isn't with float type, this result is an integer tensor and // doesn't require quantization. - auto tensor_type = value.getType().dyn_cast(); + auto tensor_type = mlir::dyn_cast(value.getType()); if (!tensor_type) { // There are none type values. return; @@ -202,9 +203,9 @@ quant::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias( for (int non_bias : non_biases) { Operation *non_bias_define = op->getOperand(non_bias).getDefiningOp(); if (auto dequant = llvm::dyn_cast(non_bias_define)) { - auto non_bias_type = dequant.getInput().getType().cast(); + auto non_bias_type = mlir::cast(dequant.getInput().getType()); auto non_bias_ele_type = - non_bias_type.getElementType().cast(); + mlir::cast(non_bias_type.getElementType()); non_bias_types.push_back(non_bias_ele_type); } else { // The non-bias hasn't been quantized, let's skip this bias. diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index 8a3abc94e2af57..5cac14867482bb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h" @@ -92,13 +93,13 @@ float CalculateRandomSparsity(const ElementsAttr& attr, int num_elements = type.getNumElements(); int num_zeros = 0; - if (type.getElementType().isa()) { + if (mlir::isa(type.getElementType())) { for (const auto val : attr.getValues()) { if (val.isZero()) { num_zeros++; } } - } else if (type.getElementType().isa()) { + } else if (mlir::isa(type.getElementType())) { for (const auto val : attr.getValues()) { if (val == 0) { num_zeros++; @@ -144,7 +145,7 @@ float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type, sparsity = GetSparsity(type.getNumElements() - format_converter.GetData().size(), type.getNumElements()); - } else if (type.getElementType().isa()) { + } else if (mlir::isa(type.getElementType())) { tflite::internal::sparsity::FormatConverter format_converter( shape, traversal_order, format, b_size, b_map); std::vector data; @@ -179,10 +180,10 @@ InspectResult InspectWeight( InspectResult result = {}; if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); - type = cst.getType().cast(); + type = mlir::cast(cst.getType()); } else if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); - type = cst.getType().cast(); + type = mlir::cast(cst.getType()); } else { result.can_compress = false; return result; @@ -229,10 +230,10 @@ std::vector BuildSparsityParameterAttribute( ShapedType type; if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); - type = cst.getType().cast(); + type = mlir::cast(cst.getType()); } else if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); - type = cst.getType().cast(); + type = mlir::cast(cst.getType()); } else { assert(false && "Expected a constant-like op"); } @@ -317,10 +318,10 @@ void DenseToSparsePass::runOnOperation() { float ratio_threshold = kBlockOverRandomSparsityRatio; if (isa(inst)) { supported_block_size = sparse_op.GetFloatBlockSize(); - type = dyn_cast(inst).getType().cast(); + type = mlir::cast(dyn_cast(inst).getType()); } else if (isa(inst)) { supported_block_size = sparse_op.GetQuantizedBlockSize(); - type = dyn_cast(inst).getType().cast(); + type = mlir::cast(dyn_cast(inst).getType()); ratio_threshold = kBlockOverRandomSparsityRatioQuant; } else { continue; @@ -341,7 +342,7 @@ void DenseToSparsePass::runOnOperation() { SparsityParameterAttr s_param; if (auto cst = dyn_cast(inst)) { auto attr = cst.getValue(); - auto type = cst.getType().cast(); + auto type = mlir::cast(cst.getType()); if (type.getElementType().isF32()) { std::vector dense_data; dense_data.reserve(type.getNumElements()); @@ -385,7 +386,7 @@ void DenseToSparsePass::runOnOperation() { } } else if (auto cst = dyn_cast(inst)) { auto attr = cst.getValue(); - auto type = cst.getType().cast(); + auto type = mlir::cast(cst.getType()); std::vector dense_data; dense_data.reserve(type.getNumElements()); for (const auto& val : attr.getValues()) diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index 51068fcf4ac67c..fe8bb7d2ca177f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -29,6 
+29,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -110,7 +111,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } // Allow dynamic width and height dimensions only. - auto result_ty = op.getResult().getType().template cast(); + auto result_ty = mlir::cast(op.getResult().getType()); if (!result_ty.hasRank() || result_ty.getRank() != 4 || result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) { return rewriter.notifyMatchFailure( @@ -187,8 +188,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( // Make sure that the axis in `expand_op` is constant. if (auto const_op = llvm::dyn_cast(expand_op.getDim().getDefiningOp())) { - expand_axis = (*const_op.getValue() - .cast() + expand_axis = (*mlir::cast(const_op.getValue()) .getValues() .begin()) .getSExtValue(); @@ -208,7 +208,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( return rewriter.notifyMatchFailure( squeeze_op, "squeeze dims should have exactly 1 dimension specified"); } - int64_t squeeze_axis = squeeze_dims[0].cast().getInt(); + int64_t squeeze_axis = mlir::cast(squeeze_dims[0]).getInt(); if (squeeze_axis < 0) { // Always squeeze 4D input to 3D input. squeeze_axis += 4; @@ -318,7 +318,8 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } if (expand_op) { - if (stb_op.getInput().getType().dyn_cast() == nullptr) { + if (mlir::dyn_cast(stb_op.getInput().getType()) == + nullptr) { return rewriter.notifyMatchFailure( stb_op, "SpaceToBatchND op's input should have RankedTensorType"); } @@ -401,7 +402,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( expand_op.setOperand(0, stb_op.getInput()); // Calculate the shape for expand. auto input_shape = - stb_op.getInput().getType().cast().getShape(); + mlir::cast(stb_op.getInput().getType()).getShape(); SmallVector expand_shape(input_shape.begin(), input_shape.end()); expand_shape.insert(expand_shape.begin() + expand_axis, 1); @@ -412,7 +413,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( // Update the conv op's output shape. auto bts_output_shape = - bts_op.getOutput().getType().cast().getShape(); + mlir::cast(bts_op.getOutput().getType()).getShape(); SmallVector conv_result_shape(bts_output_shape.begin(), bts_output_shape.end()); conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc index 252e18e191aea4..5e88048d775532 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -142,10 +143,12 @@ bool checkWhetherGraphHasValidStaticLookupTables(ModuleOp module) { // Only allow string -> int64 and int64 -> string mappings due to kernel // capability. - if (!((key_dtype.isa() && value_dtype.isa() && - value_dtype.cast().getWidth() == 64) || - (value_dtype.isa() && key_dtype.isa() && - key_dtype.cast().getWidth() == 64))) { + if (!((mlir::isa(key_dtype) && + mlir::isa(value_dtype) && + mlir::cast(value_dtype).getWidth() == 64) || + (mlir::isa(value_dtype) && + mlir::isa(key_dtype) && + mlir::cast(key_dtype).getWidth() == 64))) { return false; } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc index e8bae6eb64280f..9b0a80a4f92a71 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc @@ -84,10 +84,10 @@ void LegalizeJaxRandomPass::runOnOperation() { auto func = getOperation(); if (!IsJaxRandomUniform(func) && !IsJaxRandomNormal(func)) return; auto result_tuple_ty = - func.getFunctionType().getResult(0).dyn_cast_or_null(); + mlir::dyn_cast_or_null(func.getFunctionType().getResult(0)); if (!result_tuple_ty) return; if (result_tuple_ty.size() != 1) return; - auto result_ty = result_tuple_ty.getType(0).dyn_cast(); + auto result_ty = mlir::dyn_cast(result_tuple_ty.getType(0)); func.eraseBody(); func.addEntryBlock(); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index cfe9bc754d8077..240773a82a9657 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -73,9 +73,6 @@ def CreateTFCastToInt32Op : NativeCodeCall< def CreateInt32ConstOrCast : NativeCodeCall< "CreateInt32ConstOrCast($0, $_loc, $_builder)">; -def CreateNoneValue : NativeCodeCall< - "$_builder.create($0.getLoc(), $_builder.getUnitAttr())">; - // Creates an int32 constant op from an integer attribute $0. def CreateInt32ConstOpFromIntAttr : NativeCodeCall<"$_builder.create($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast($0.cast().getInt())}))">; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tensorlist.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tensorlist.cc index d4f58c00eea4f5..10adf2434acbca 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tensorlist.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tensorlist.cc @@ -40,12 +40,12 @@ limitations under the License. 
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace { using ::mlir::MLIRContext; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 38a8bffd87bb03..2011b6d33ccd45 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -58,7 +58,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" -#include "xla/status.h" #include "xla/statusor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -96,7 +95,7 @@ class LegalizeTFPass : public impl::LegalizeTFPassBase { // Util that casts 'val' to Int32 by adding a cast Op. Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { IntegerType new_ele_type = rewriter.getIntegerType(32); - if (auto shaped_type = val.getType().dyn_cast()) { + if (auto shaped_type = mlir::dyn_cast(val.getType())) { ShapedType new_type = RankedTensorType::get(shaped_type.getShape(), new_ele_type); return rewriter.createOrFold(loc, new_type, val, @@ -114,7 +113,7 @@ Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { // 2. In the default case, cast the `Value` to an int32_t. Value CreateInt32ConstOrCast(Value val, Location loc, PatternRewriter& rewriter) { - if (val.getType().cast().hasStaticShape()) { + if (mlir::cast(val.getType()).hasStaticShape()) { DenseElementsAttr shape_value_attr; if (matchPattern(val, m_Constant(&shape_value_attr))) { SmallVector new_shape_array_i32; @@ -137,7 +136,7 @@ Value CreateInt32ConstOrCast(Value val, Location loc, // Get shape of an operand or result, support both dynamic and static shape. Value GetShape(Value input, Location loc, PatternRewriter& rewriter) { - auto shaped_type = input.getType().cast(); + auto shaped_type = mlir::cast(input.getType()); if (shaped_type.hasStaticShape()) { auto static_shape = shaped_type.getShape(); auto static_shape_type = @@ -271,7 +270,7 @@ bool ConvertTFBatchMatMulOp2TFLFullyConnectedOp(Operation* bmm_op, // Create a tfl.transpose op that performs ZX transpose on `input`. auto create_z_x_transpose_op = [&](Value input) -> Value { - RankedTensorType input_type = input.getType().cast(); + RankedTensorType input_type = mlir::cast(input.getType()); const int input_rank = input_type.getRank(); // Create a 1D I32 tensor for representing the dimension permutation. 
@@ -364,7 +363,7 @@ LogicalResult ConvertTFMatMulOp::matchAndRewrite( auto rhs = op->getOperand(1); auto transpose = [&](Value input) -> std::pair { RankedTensorType type = - input.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input.getType()); if (!type || type.getRank() != 2) return {failure(), nullptr}; auto permute_attr = DenseIntElementsAttr::get( @@ -583,15 +582,15 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) { // Verify padding_value is a tensor with all 0s. mlir::Value padding_value = tf_matrix_diag_v2_or_v3_op.getPaddingValue(); mlir::Type element_type = - padding_value.getType().cast().getElementType(); - if (element_type.isa()) { + mlir::cast(padding_value.getType()).getElementType(); + if (mlir::isa(element_type)) { DenseFPElementsAttr padding_attr; if (!matchPattern(padding_value, m_Constant(&padding_attr)) || !padding_attr.isSplat() || !padding_attr.getSplatValue().isZero()) { return false; } - } else if (element_type.isa()) { + } else if (mlir::isa(element_type)) { DenseIntElementsAttr padding_attr; if (!matchPattern(padding_value, m_Constant(&padding_attr)) || !padding_attr.isSplat() || @@ -642,7 +641,7 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { SmallVector tflite_indices; for (auto index_attr : tflite_indices_attr.getValue()) { - IntegerAttr index = index_attr.cast(); + IntegerAttr index = mlir::cast(index_attr); tflite_indices.push_back(index.getInt()); } @@ -773,13 +772,13 @@ class ApplyExplicitBroadcasting : public OpRewritePattern { SmallVector symbolic_broadcast_shape; // Matches fail when lhs or rhs is unranked tensor. // TODO(b/176202543): Support unranked tensor. - if (!lhs.getType().cast().hasRank() || - !rhs.getType().cast().hasRank()) { + if (!mlir::cast(lhs.getType()).hasRank() || + !mlir::cast(rhs.getType()).hasRank()) { return failure(); } if (!OpTrait::util::getBroadcastedShape( - lhs.getType().cast().getShape(), - rhs.getType().cast().getShape(), + mlir::cast(lhs.getType()).getShape(), + mlir::cast(rhs.getType()).getShape(), symbolic_broadcast_shape)) { return failure(); } @@ -824,13 +823,13 @@ class ApplyExplicitBroadcasting : public OpRewritePattern { auto lhs = op->getOperand(0); auto rhs = op->getOperand(1); - if (!lhs.getType().cast().hasStaticShape() || - !rhs.getType().cast().hasStaticShape()) { + if (!mlir::cast(lhs.getType()).hasStaticShape() || + !mlir::cast(rhs.getType()).hasStaticShape()) { return rewriteOpWithDynamicInput(op, rewriter); } - auto lhs_shape = lhs.getType().cast().getShape(); - auto rhs_shape = rhs.getType().cast().getShape(); + auto lhs_shape = mlir::cast(lhs.getType()).getShape(); + auto rhs_shape = mlir::cast(rhs.getType()).getShape(); if (lhs_shape == rhs_shape) { return failure(); @@ -892,23 +891,23 @@ class ApplyExplicitBroadcasting // Matches fail when lhs|rhs|cond is unranked tensor. // TODO(b/176202543): Support unranked tensor. - if (!lhs.getType().cast().hasRank() || - !rhs.getType().cast().hasRank() || - !cond.getType().cast().hasRank()) { + if (!mlir::cast(lhs.getType()).hasRank() || + !mlir::cast(rhs.getType()).hasRank() || + !mlir::cast(cond.getType()).hasRank()) { return failure(); } // Calculates symbolic broadcast shape that is only used in types. 
SmallVector symbolic_broadcast_lhs_rhs_shape; if (!OpTrait::util::getBroadcastedShape( - lhs.getType().cast().getShape(), - rhs.getType().cast().getShape(), + mlir::cast(lhs.getType()).getShape(), + mlir::cast(rhs.getType()).getShape(), symbolic_broadcast_lhs_rhs_shape)) { return failure(); } SmallVector symbolic_broadcast_shape; if (!OpTrait::util::getBroadcastedShape( - cond.getType().cast().getShape(), + mlir::cast(cond.getType()).getShape(), symbolic_broadcast_lhs_rhs_shape, symbolic_broadcast_shape)) { return failure(); } @@ -964,15 +963,15 @@ class ApplyExplicitBroadcasting auto rhs = op->getOperand(2); // Should have static shapes to calculate the broadcasted shape. - if (!lhs.getType().cast().hasStaticShape() || - !rhs.getType().cast().hasStaticShape() || - !cond.getType().cast().hasStaticShape()) { + if (!mlir::cast(lhs.getType()).hasStaticShape() || + !mlir::cast(rhs.getType()).hasStaticShape() || + !mlir::cast(cond.getType()).hasStaticShape()) { return rewriteOpWithDynamicInput(op, rewriter); } - auto lhs_shape = lhs.getType().cast().getShape(); - auto rhs_shape = rhs.getType().cast().getShape(); - auto cond_shape = cond.getType().cast().getShape(); + auto lhs_shape = mlir::cast(lhs.getType()).getShape(); + auto rhs_shape = mlir::cast(rhs.getType()).getShape(); + auto cond_shape = mlir::cast(cond.getType()).getShape(); if (lhs_shape == rhs_shape && cond_shape == lhs_shape) { return failure(); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.cc index 7098b2f75157da..7742ea06976c00 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -67,7 +68,7 @@ class LegalizeVariablesPass // If TFLite variable legalization is not allowed, then we skip this pass. 
if (auto legalize_tfl_variables_attr = module->getAttr(kLegalizeTflVariables)) { - if (!legalize_tfl_variables_attr.cast().getValue()) return; + if (!mlir::cast(legalize_tfl_variables_attr).getValue()) return; } RewritePatternSet patterns(&getContext()); diff --git a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc index e212ce16ee6ccd..747e96d40b6850 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc @@ -117,7 +117,7 @@ class LiftFlexCustomOp : public OpRewritePattern { // TODO(b/146131919): correct handling of resource type if (auto tensor_array_v3_op = dyn_cast(tf_op)) { Value handle = tensor_array_v3_op.getHandle(); - auto handle_type = handle.getType().cast(); + auto handle_type = mlir::cast(handle.getType()); if (handle_type.getElementType().isInteger(/*width=*/32)) { Type resource_tensor_type = handle_type.clone(TF::ResourceType::get(rewriter.getContext())); @@ -225,8 +225,8 @@ class LiftFlexCustomOp : public OpRewritePattern { return emitError(loc, mlir_attr.status().message()); } if (absl::StrContains(op_name, "Dataset") && - mlir_attr->isa()) { - mlir_attr = mlir_attr->cast().getName(); + mlir::isa(*mlir_attr)) { + mlir_attr = mlir::cast(*mlir_attr).getName(); } attributes.push_back(builder.getNamedAttr(attr_name, *mlir_attr)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc index a8adac41229277..7fea1e395ea209 100644 --- a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc +++ b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc @@ -94,7 +94,7 @@ LogicalResult ModifyIONodesPass::SetupInputOutputTypesIfNull( LogicalResult ModifyIONodesPass::ModifyInputNodes( func::FuncOp func, llvm::SmallVectorImpl& new_input_types, OpBuilder builder) { - if (input_type.isa()) { + if (mlir::isa(input_type)) { return success(); } @@ -151,7 +151,7 @@ LogicalResult ModifyIONodesPass::ModifyOutputNodes( auto* terminator = block.getTerminator(); builder.setInsertionPoint(terminator); - if (output_type.isa()) { + if (mlir::isa(output_type)) { return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 1fc84007a64cce..606be04a0f7d6b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -81,7 +81,7 @@ constexpr char kRelu6[] = "RELU6"; constexpr char kRelu1[] = "RELU_N1_TO_1"; ElementsAttr FlattenTo1D(Attribute a) { - auto elements = a.cast(); + auto elements = mlir::cast(a); const std::array flattened_shape = {elements.getNumElements()}; auto new_type = RankedTensorType::get(flattened_shape, elements.getType().getElementType()); @@ -91,8 +91,8 @@ ElementsAttr FlattenTo1D(Attribute a) { // This assumes that the bias is of shape NxCx1x1 and doesn't require transpose // Its corresponding constraint is optimize_patterns.td:IsBiasShape() ElementsAttr ReshapeNCHWBiasToNHWC(Value v, Attribute a) { - auto elements = a.cast(); - auto shape = v.getType().cast().getShape(); + auto elements = mlir::cast(a); + auto shape = mlir::cast(v.getType()).getShape(); if (shape.size() != 4 || shape[2] != 1 || shape[3] != 1) return elements; const std::array new_shape = {shape[0], shape[2], shape[3], shape[1]}; @@ -105,15 +105,16 @@ bool L2NormalizeReduceAxis(Value sq_op, 
DenseElementsAttr axis) { if (axis.getNumElements() == 0) { return false; } - if (sq_op.getType().cast().getRank() - 1 == + if (mlir::cast(sq_op.getType()).getRank() - 1 == *axis.getValues().begin() || *axis.getValues().begin() == -1) { return true; } - if (sq_op.getType().cast().getRank() != axis.getNumElements()) { + if (mlir::cast(sq_op.getType()).getRank() != + axis.getNumElements()) { return false; } - auto shape = sq_op.getType().cast(); + auto shape = mlir::cast(sq_op.getType()); SmallVector elems{axis.getValues().begin(), axis.getValues().end()}; for (int i = 0; i < shape.getRank(); ++i) { @@ -144,9 +145,10 @@ class OptimizePass : public impl::OptimizePassBase { // is equal to the non-contracting dimension after a reshape bool BroadcastDimsProductEqual(Value input, Value output, size_t agg_start_idx) { - ArrayRef input_shape = input.getType().cast().getShape(); + ArrayRef input_shape = + mlir::cast(input.getType()).getShape(); ArrayRef output_shape = - output.getType().cast().getShape(); + mlir::cast(output.getType()).getShape(); int64_t agg_value = 1; for (size_t i = agg_start_idx; i < input_shape.size() - 1; ++i) { @@ -166,7 +168,7 @@ bool IsBroadcastableElementsAttrAndType(Type a, Type b) { // broadcast-compatible with `b`. bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) { Type output_element_type = - expected_output.cast().getElementType(); + mlir::cast(expected_output).getElementType(); Type broadcasted_type = OpTrait::util::getBroadcastedType(a, b, output_element_type); return broadcasted_type != Type() && broadcasted_type == expected_output; @@ -175,8 +177,8 @@ bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) { // Returns whether if `type1` dimensions are the same as the ending dimensions // of `type2`. This is more restricted than broadcastable. bool IsTailOfShape(Type type1, Type type2) { - auto tail_type = type1.dyn_cast(); - auto full_type = type2.dyn_cast(); + auto tail_type = mlir::dyn_cast(type1); + auto full_type = mlir::dyn_cast(type2); if (!tail_type || !full_type || !tail_type.hasRank() || !full_type.hasRank() || tail_type.getRank() > full_type.getRank()) return false; @@ -189,8 +191,8 @@ bool IsTailOfShape(Type type1, Type type2) { // the reduced `type1` dimensions are the same as the ending dimensions // of `type2`. bool IsReducedTailOfShape(Type type1, Type type2) { - auto tail_type = type1.dyn_cast(); - auto full_type = type2.dyn_cast(); + auto tail_type = mlir::dyn_cast(type1); + auto full_type = mlir::dyn_cast(type2); if (!tail_type || !full_type || !tail_type.hasRank() || !full_type.hasRank()) return false; @@ -211,10 +213,10 @@ bool IsReducedTailOfShape(Type type1, Type type2) { // elements in type2. This is a required condition to flatten type2 to form a // 1D array and allow the binaryOp handle the broadcasting implicitly. 
bool IsLastDimEqualToNumElements(Type type1, Type type2) { - return (type1.cast().getRank() >= 1 && - type1.cast().getDimSize( - type1.cast().getRank() - 1) == - type2.cast().getNumElements()); + return (mlir::cast(type1).getRank() >= 1 && + mlir::cast(type1).getDimSize( + mlir::cast(type1).getRank() - 1) == + mlir::cast(type2).getNumElements()); } bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef filter_shape, @@ -249,20 +251,21 @@ bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef filter_shape, bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val, bool is_depthwise) { - const auto elements = val.dyn_cast(); + const auto elements = mlir::dyn_cast(val); if (!elements) { return false; } const auto elements_shape = elements.getType().getShape(); - const auto filter_shape = filter.getType().cast().getShape(); + const auto filter_shape = mlir::cast(filter.getType()).getShape(); return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape, is_depthwise); } bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, bool is_depthwise) { - if (const auto elements = val.dyn_cast()) { - if (const auto filter_elements = filter.dyn_cast()) { + if (const auto elements = mlir::dyn_cast(val)) { + if (const auto filter_elements = + mlir::dyn_cast(filter)) { return CanFuseConvOrDepthwiseConvShapes( filter_elements.getType().getShape(), elements.getType().getShape(), is_depthwise); @@ -277,8 +280,8 @@ bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, DenseIntElementsAttr indices, Type output_type) { - auto params_type = params.getType().dyn_cast(); - auto indices_type = indices.getType().dyn_cast(); + auto params_type = mlir::dyn_cast(params.getType()); + auto indices_type = mlir::dyn_cast(indices.getType()); // Checks the shape of `params` is [n, ...], shape of `indices` is [n, 1]. 2D // `indices` means it gets the first row of `params`. As long as indices // iterate the first row of `params`, the output is identical to input. @@ -306,8 +309,8 @@ bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, // for each dim i, the output tensor is identical to `input`. bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) { // Checks if `begin` and `size` are i32 or i64. - auto begin_attr = begin.dyn_cast(); - auto size_attr = size.dyn_cast(); + auto begin_attr = mlir::dyn_cast(begin); + auto size_attr = mlir::dyn_cast(size); if (!begin_attr || !size_attr) { return false; } @@ -323,7 +326,7 @@ bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) { // Checks if `input` is ranked and its rank is equal to number of elements in // `begin` and `size`. - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); if (!input_ty.hasRank()) { return false; } @@ -348,7 +351,7 @@ bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) { // Expand Attribute 'a' to 4D with all 1s except 1 dimension. // Which dimension depends on 'is_depthwise' is true or false. ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) { - auto elements = a.dyn_cast(); + auto elements = mlir::dyn_cast(a); auto shape = elements.getType().getShape(); if (!shape.empty()) { // Checks that elements are essentially 1d. 
@@ -377,46 +380,19 @@ TypeAttr RescaleQtype(Type input, Attribute factor) { return quant::RescaleQuantizedType(input, factor); } -// Utility function to map final permutation to initial permutation -// initial -> permutation1 -> permutation2 -> final -DenseElementsAttr RemapPermutation(Value permutation1, Value permutation2) { - SmallVector initial_permutation; - DenseElementsAttr perm1_const; - DenseElementsAttr perm2_const; - - SmallVector new_permutation; - if (matchPattern(permutation1, m_Constant(&perm1_const)) && - matchPattern(permutation2, m_Constant(&perm2_const))) { - for (int32_t idx = 0; idx < perm1_const.getNumElements(); ++idx) { - initial_permutation.push_back(idx); - } - for (auto perm : perm2_const.getValues()) { - new_permutation.push_back( - initial_permutation[perm1_const - .getValues()[perm.getSExtValue()] - .getSExtValue()]); - } - } - - return mlir::DenseElementsAttr::get( - RankedTensorType::get( - {static_cast(new_permutation.size())}, - mlir::IntegerType::get(permutation1.getContext(), 32)), - llvm::ArrayRef(new_permutation)); -} - // Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in // the specified `shape` and `false` otherwise. static bool ShapeMatchesReduceWithKeepAxes(Value input, const mlir::Attribute &axes, const mlir::Attribute &shape) { - RankedTensorType type = input.getType().dyn_cast_or_null(); + RankedTensorType type = + mlir::dyn_cast_or_null(input.getType()); if (!type) return false; DenseIntElementsAttr axes_attr = - axes.dyn_cast_or_null(); + mlir::dyn_cast_or_null(axes); DenseIntElementsAttr shape_attr = - shape.dyn_cast_or_null(); + mlir::dyn_cast_or_null(shape); if (!axes_attr || !shape_attr) return false; if (shape_attr.getNumElements() != type.getRank()) return false; @@ -441,12 +417,12 @@ static bool ShapeMatchesReduceWithKeepAxes(Value input, static bool AreInputDimensionsOneInAxes(Value input, const mlir::Attribute &axes) { RankedTensorType input_type = - input.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input.getType()); if (!input_type) return false; auto type_shape = input_type.getShape(); DenseIntElementsAttr axes_attr = - axes.dyn_cast_or_null(); + mlir::dyn_cast_or_null(axes); if (!axes_attr) return false; for (auto a : axes_attr.getValues()) { @@ -467,7 +443,7 @@ static bool AreInputDimensionsOneInAxes(Value input, } static bool FloatValueEquals(const Attribute &attr, double value) { - auto fp_attr = attr.dyn_cast_or_null(); + auto fp_attr = mlir::dyn_cast_or_null(attr); if (!fp_attr) return false; if (fp_attr.isSplat()) { @@ -482,12 +458,12 @@ static bool FloatValueEquals(const Attribute &attr, double value) { // to `raw_value`. template bool IsConstantValueOf(mlir::TypedAttr value, T raw_value) { - auto element_type = value.getType().cast().getElementType(); + auto element_type = mlir::cast(value.getType()).getElementType(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { return FloatValueEquals(value, raw_value); - } else if (element_type.isa()) { - auto int_attr = value.dyn_cast_or_null(); + } else if (mlir::isa(element_type)) { + auto int_attr = mlir::dyn_cast_or_null(value); if (!int_attr) return false; if (int_attr.isSplat()) { @@ -502,13 +478,13 @@ bool IsConstantValueOf(mlir::TypedAttr value, T raw_value) { // Returns true if the value's element type is F32. 
bool IsF32Value(Value value) { - return value.getType().cast().getElementType().isF32(); + return mlir::cast(value.getType()).getElementType().isF32(); } // Returns the number of elements in attr if it is a static shape, 1 otherwise, // as an unranked int32 Attribute. TypedAttr GetNumElementsOrOne(Type type) { - auto shaped_type = type.cast(); + auto shaped_type = mlir::cast(type); int32_t num_elements = shaped_type.hasStaticShape() ? shaped_type.getNumElements() : 1; @@ -523,7 +499,7 @@ TypedAttr GetNumElementsOrOne(Type type) { Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value) { // This function is always guarded with HasTrivialShapeExceptSecondLastDim(), // so we could cast safely here. - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); SmallVector new_shape; if (type.hasStaticShape()) { for (int64_t dim : type.getShape().drop_back()) { @@ -543,7 +519,7 @@ Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value) { // Returns true if val has a static shape and the last dimension equals 1. bool IsLastDimensionEqualOne(Value val) { - const auto val_type = val.getType().cast(); + const auto val_type = mlir::cast(val.getType()); if (!val_type.hasStaticShape()) return false; const auto val_shape = val_type.getShape(); if (val_shape.empty()) return false; @@ -577,7 +553,7 @@ bool HasOneUseOrUsedByOnlyBinaryOps(Value out_value) { // // If such a value is used in an Equal operator, it can be replaced with OneHot. bool IsOneHotIndexAttribute(Attribute attr) { - const auto dense_attr = attr.dyn_cast_or_null(); + const auto dense_attr = mlir::dyn_cast_or_null(attr); if (!dense_attr) { return false; } @@ -602,7 +578,7 @@ bool IsOneHotIndexAttribute(Attribute attr) { } Value Get1DShapeValue(OpBuilder &builder, Value value) { - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); if (!type.hasStaticShape()) { return nullptr; } @@ -614,11 +590,11 @@ Value Get1DShapeValue(OpBuilder &builder, Value value) { } Type GetEmbeddingLookupShape(Value lookup, Value value) { - auto lookup_type = lookup.getType().cast(); + auto lookup_type = mlir::cast(lookup.getType()); if (!lookup_type.hasStaticShape()) { return nullptr; } - auto value_type = value.getType().cast(); + auto value_type = mlir::cast(value.getType()); if (!value_type.hasStaticShape() || value_type.getRank() != 2) { return nullptr; } @@ -665,7 +641,7 @@ bool IsF32Splat(Attribute input_splat) { // Attribute holding a single value of float type. If attr has no elements, the // result is 0.0f. TypedAttr ConvertSingleElementAttrToFloatAttr(Attribute attr) { - const auto dense_fp_attr = attr.dyn_cast_or_null(); + const auto dense_fp_attr = mlir::dyn_cast_or_null(attr); if (dense_fp_attr) { // Already float => return return dense_fp_attr; @@ -673,7 +649,7 @@ TypedAttr ConvertSingleElementAttrToFloatAttr(Attribute attr) { OpBuilder builder(attr.getContext()); - const auto dense_int_attr = attr.dyn_cast(); + const auto dense_int_attr = mlir::dyn_cast(attr); const auto int_values = dense_int_attr.getValues(); float float_val = 0.0f; if (!int_values.empty()) { @@ -793,9 +769,7 @@ struct SqueezeReshapesAroundBroadcastOp // Pattern is applied only if the broadcast_to shape has more than 5 // dimensions. 
- if (tfl_broadcast_to_op.getShape() - .getType() - .cast() + if (mlir::cast(tfl_broadcast_to_op.getShape().getType()) .getNumElements() < 6) { return rewriter.notifyMatchFailure(loc, "Not supported broadcast_to shape"); @@ -831,7 +805,7 @@ struct SqueezeReshapesAroundBroadcastOp // Calculate the number of extra leading and trailing 1s in the // broadcast_op output. auto broadcast_output_shapetype = - tfl_broadcast_to_op.getOutput().getType().cast(); + mlir::cast(tfl_broadcast_to_op.getOutput().getType()); int num_leading_broadcast_dims = GetNumLeadingOnes(broadcast_output_shapetype); int num_trailing_broadcast_dims = @@ -839,9 +813,7 @@ struct SqueezeReshapesAroundBroadcastOp // Get the new shape for the inner reshape_op after removing the extra 1s. llvm::SmallVector new_reshape_shape_i32{ - inner_reshape_op.getOutput() - .getType() - .cast() + mlir::cast(inner_reshape_op.getOutput().getType()) .getShape() .drop_back(num_trailing_broadcast_dims) .drop_front(num_leading_broadcast_dims)}; @@ -886,11 +858,11 @@ struct ConvertTFLBroadcastToMulOp LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, PatternRewriter &rewriter) const override { auto input_type = - tfl_broadcast_to_op.getInput().getType().cast(); + mlir::cast(tfl_broadcast_to_op.getInput().getType()); auto output_type = - tfl_broadcast_to_op.getOutput().getType().cast(); + mlir::cast(tfl_broadcast_to_op.getOutput().getType()); auto shape_type = - tfl_broadcast_to_op.getShape().getType().cast(); + mlir::cast(tfl_broadcast_to_op.getShape().getType()); Type element_type = input_type.getElementType(); auto loc = tfl_broadcast_to_op->getLoc(); @@ -909,7 +881,7 @@ struct ConvertTFLBroadcastToMulOp // Allow lowering when the input's elements type is F32, BFloat16, I32 or // I16. - if (!(element_type.isa() || + if (!(mlir::isa(element_type) || element_type.isInteger(32) || element_type.isInteger(16))) return rewriter.notifyMatchFailure(loc, "element_type_not_supported"); @@ -986,7 +958,7 @@ struct FuseAddAndStridedSlice : public OpRewritePattern { return failure(); mlir::TensorType constant_val_type = - constant_val.getType().cast(); + mlir::cast(constant_val.getType()); // If it's not 1D or 0D (which can be broadcasted to 1D), reject the // matching. if (constant_val_type.getRank() > 1) { @@ -994,14 +966,14 @@ struct FuseAddAndStridedSlice : public OpRewritePattern { } mlir::RankedTensorType end_type = - strided_slice_op.getEnd().getType().dyn_cast(); + mlir::dyn_cast(strided_slice_op.getEnd().getType()); // begin, end and strides are Rank 1 tensors with one element per dimension // of input. int64_t num_dims = end_type.getShape()[0]; DenseElementsAttr new_added_value = added_value.reshape(RankedTensorType::get( {num_dims}, - added_value.getType().cast().getElementType())); + mlir::cast(added_value.getType()).getElementType())); ::mlir::arith::ConstantOp new_end = rewriter.create( strided_slice_op.getEnd().getLoc(), new_added_value); @@ -1183,7 +1155,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { add_op.getLhs().getDefiningOp()); if (!fc_op) return failure(); - auto constant_val_type = constant_val.getType().cast(); + auto constant_val_type = mlir::cast(constant_val.getType()); // In TFLite FullyConnect definition, bias must be a 1D tensor where // the number of elements is equal to the number of channels. 
@@ -1199,7 +1171,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { Value filter = fc_op.getFilter(); Value bias = fc_op.getBias(); ElementsAttr bias_value; - const bool is_none_bias = bias.getType().isa(); + const bool is_none_bias = mlir::isa(bias.getType()); if (fc_op.getFusedActivationFunction() != "NONE") return failure(); if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value))) @@ -1212,7 +1184,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { // to properly broadcast the scalar to `{num_channels}` shape. // Get the number of channels if possible. - auto filter_type = filter.getType().dyn_cast(); + auto filter_type = mlir::dyn_cast(filter.getType()); // Filter must be a `2D` tensor with `{num_channels, num_features}` // shape. The following check is rejecting unknown rank (-1). if (filter_type == nullptr || filter_type.getRank() != 2) { @@ -1287,14 +1259,14 @@ struct FuseAddAndFullyConnected // Don't match adds where the added constant is not 1D. { - auto addend_shape = add_op.getRhs().getType().cast(); + auto addend_shape = mlir::cast(add_op.getRhs().getType()); if (!addend_shape.hasStaticShape()) return failure(); if (addend_shape.getShape().size() != 1) return failure(); } // Calculate new bias. Generate a new FC; it will be constant folded. auto old_bias = fc_op.getBias(); - if (!old_bias || old_bias.getType().isa()) { + if (!old_bias || mlir::isa(old_bias.getType())) { // TODO(b/180752069): Figure out new bias' type when old bias is empty. return failure(); } @@ -1358,7 +1330,7 @@ struct FuseMulAndFullyConnected // Don't match muls where the multiplier constant is not 1D. { - auto multiplier_shape = mul_op.getRhs().getType().cast(); + auto multiplier_shape = mlir::cast(mul_op.getRhs().getType()); if (!multiplier_shape.hasStaticShape()) return failure(); if (multiplier_shape.getShape().size() != 1) return failure(); } @@ -1464,7 +1436,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { Value bias = fc_op.getBias(); ElementsAttr cst_tmp; if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure(); - if (!bias.getType().isa() && + if (!mlir::isa(bias.getType()) && !matchPattern(bias, m_Constant(&cst_tmp))) return failure(); if (fc_op.getFusedActivationFunction() != "NONE") return failure(); @@ -1494,7 +1466,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands, // TF::MulOp is used to fold the constant. // TODO(b/139192933): switch to the TFL constant folding - auto filter_type = filter.getType().cast(); + auto filter_type = mlir::cast(filter.getType()); if (filter_type.hasStaticShape()) { auto size = filter_type.getNumElements() * filter_type.getElementTypeBitWidth(); @@ -1506,7 +1478,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { rewriter.create(mul_op.getLoc(), filter, new_const_val) .getZ(); // If bias isn't None, it needs to be multiplied as well. 
- if (!bias.getType().isa()) { + if (!mlir::isa(bias.getType())) { bias = rewriter.create(mul_op.getLoc(), bias, constant_val) .getZ(); } @@ -1585,7 +1557,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { // weight constant ElementsAttr cst_tmp; if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure(); - if (!bias.getType().isa() && + if (!mlir::isa(bias.getType()) && !matchPattern(bias, m_Constant(&cst_tmp))) return failure(); if (fc_op.getFusedActivationFunction() != "NONE") return failure(); @@ -1607,7 +1579,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { } // Make sure that the fused bias will be a 1D tensor. - auto gamma_shape = gamma.getType().cast(); + auto gamma_shape = mlir::cast(gamma.getType()); if (!gamma_shape.hasRank() || gamma_shape.getRank() != 1) { return failure(); } @@ -1623,7 +1595,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { new_filter, new_qtype); // If bias isn't None, it needs to be multiplied as well. - if (!bias.getType().isa()) { + if (!mlir::isa(bias.getType())) { rewriter.setInsertionPoint(fc_op); auto new_bias = rewriter.create(loc, bias, gamma); fc_op.getOperation()->replaceUsesOfWith(bias, new_bias); @@ -1674,7 +1646,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { } filter = q.getInput(); } - if (!bias.getType().isa() && + if (!mlir::isa(bias.getType()) && !matchPattern(bias, m_Constant(&bias_cst))) return failure(); auto binary_op_activation_func = @@ -1705,7 +1677,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { // The new bias should be a 1-D tensor with length equals to the bias // dimension of the weight. SmallVector new_bias_values; - if (bias.getType().isa()) { // none bias, a list of zeros + if (mlir::isa(bias.getType())) { // none bias, a list of zeros new_bias_values.resize(bias_size, APFloat::getZero(cst_value.getSemantics())); } else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it @@ -1806,12 +1778,11 @@ struct ScalarizeSplatConstantForBroadcastableOps } constexpr int kSplatOperandIndex = 1; - auto result_type = - binary_op.getResult().getType().template cast(); + auto result_type = mlir::cast(binary_op.getResult().getType()); mlir::Value non_splat_operand = binary_op.getOperand(1 - kSplatOperandIndex); auto non_splat_operand_type = - non_splat_operand.getType().cast(); + mlir::cast(non_splat_operand.getType()); // If the other operand's shape does not equal to the result shape, then we // cannot scalarize the splat constant because the result shape relies on // the splat constant op's shape for broadcasting. @@ -1850,10 +1821,11 @@ struct ScalarizeSplatConstantForBroadcastableOps if (!matchPattern(value, m_Constant(elements_attr))) { return false; } - auto element_type = value.getType().cast().getElementType(); + auto element_type = + mlir::cast(value.getType()).getElementType(); // Ignore per-axis quantized constants because after converting to scalar, // we will lose per-axis qantization parameter. - if (element_type.isa()) { + if (mlir::isa(element_type)) { return false; } if (IsScalar(value)) { @@ -1864,7 +1836,7 @@ struct ScalarizeSplatConstantForBroadcastableOps // If this type is a scalar shaped type. bool IsScalar(mlir::Value value) const { - auto type = value.getType().dyn_cast(); + auto type = mlir::dyn_cast(value.getType()); if (!type) { return false; } @@ -1883,7 +1855,7 @@ struct ScalarizeSplatConstantForBroadcastableOps DenseElementsAttr value; // Check that bias are constants if not none. 
Value bias = affine_op->getOperand(2); - if (!bias.getType().isa() && + if (!mlir::isa(bias.getType()) && !matchPattern(bias, m_Constant(&value))) { return false; } @@ -1896,7 +1868,7 @@ struct ScalarizeSplatConstantForBroadcastableOps // We can only fuse F32/BF16. auto is_fusable_type = [](Type t) { Type element_type = t; - if (auto shaped_type = t.dyn_cast()) { + if (auto shaped_type = mlir::dyn_cast(t)) { element_type = shaped_type.getElementType(); } return element_type.isBF16() || element_type.isF32(); @@ -1920,68 +1892,6 @@ using ScalarizeSplatConstantForMul = using ScalarizeSplatConstantForDiv = ScalarizeSplatConstantForBroadcastableOps; -struct ConvertTrivialTransposeOpToReshapeOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op, - PatternRewriter &rewriter) const override { - auto input_type = transpose_op.getInput().getType().cast(); - auto output_type = transpose_op.getOutput().getType().cast(); - // It's possible to know if the transformation is safe only if the input - // & output shapes are fully known and permutation is a constant. - if (!input_type.hasStaticShape() || !output_type.hasStaticShape()) - return failure(); - Value perm = transpose_op.getPerm(); - DenseElementsAttr perm_values_attr; - if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure(); - - auto input_shape = input_type.getShape(); - SmallVector perm_values; - for (const auto &dim : perm_values_attr.getValues()) - perm_values.push_back(dim.getSExtValue()); - - // This should never happen unless the input graph is malformed. - if (input_shape.size() != perm_values.size()) { - transpose_op.emitError( - "TransposeOP has inconsistent input and perm values."); - } - - SmallVector old_major_index_ordering; - SmallVector new_major_index_ordering; - for (int i = 0, end = input_shape.size(); i < end; i++) { - if (input_shape[i] != 1) { - old_major_index_ordering.push_back(i); - } - - if (input_shape[perm_values[i]] != 1) { - new_major_index_ordering.push_back(perm_values[i]); - } - } - if (old_major_index_ordering != new_major_index_ordering) { - return failure(); - } - - // Rewrite. - Location loc = transpose_op.getLoc(); - - SmallVector output_shape_values; - for (auto dim : output_type.getShape()) { - output_shape_values.push_back( - ShapedType::isDynamic(dim) ? -1 : static_cast(dim)); - } - - auto new_shape = rewriter.create( - loc, GetI32ElementsAttr(output_shape_values, &rewriter)); - - rewriter.replaceOpWithNewOp( - transpose_op, transpose_op.getOutput().getType(), - transpose_op.getInput(), new_shape); - - return success(); - } -}; - // Remove Reshape before FullyConnected when `keep_num_dims=false` and Reshape // does not alter the last dimension as FullyConnected will collapse all other // dimensions into a single dimension. 
For example, @@ -2002,10 +1912,9 @@ struct RemoveReshapeBeforeFullyConnected LogicalResult matchAndRewrite(TFL::FullyConnectedOp fully_connected_op, PatternRewriter &) const override { auto input = fully_connected_op.getInput(); - auto input_ty = input.getType().dyn_cast(); - auto output_ty = fully_connected_op.getOutput()[0] - .getType() - .template dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); + auto output_ty = + mlir::dyn_cast(fully_connected_op.getOutput()[0].getType()); if (!input_ty.hasStaticShape() || fully_connected_op.getWeightsFormat() != "DEFAULT" || fully_connected_op.getKeepNumDims() || !output_ty.hasStaticShape() || @@ -2018,7 +1927,7 @@ struct RemoveReshapeBeforeFullyConnected // Check if the last dimension does not change after reshape. auto reshape_input = reshape_op.getInput(); - auto reshape_input_ty = reshape_input.getType().dyn_cast(); + auto reshape_input_ty = mlir::dyn_cast(reshape_input.getType()); if (!reshape_input_ty.hasStaticShape() || input_ty.getRank() == 0 || reshape_input_ty.getRank() == 0 || input_ty.getDimSize(input_ty.getRank() - 1) != @@ -2061,9 +1970,9 @@ struct RemoveReshapeAfterFullyConnected if (!reshape_op.getInput().hasOneUse()) return failure(); auto input_shape = - fully_connected_op.getInput().getType().cast(); - auto output_shape = fully_connected_op.getType(0).cast(); - auto reshape_shape = reshape_op.getType().cast(); + mlir::cast(fully_connected_op.getInput().getType()); + auto output_shape = mlir::cast(fully_connected_op.getType(0)); + auto reshape_shape = mlir::cast(reshape_op.getType()); if (!input_shape.hasStaticShape() || !output_shape.hasStaticShape() || !reshape_shape.hasStaticShape()) return failure(); @@ -2128,7 +2037,7 @@ struct FuseUnpackAndConcatToReshape } } - auto output_type = concat_op.getType().cast(); + auto output_type = mlir::cast(concat_op.getType()); if (!output_type.hasStaticShape()) { return failure(); } @@ -2188,8 +2097,8 @@ struct OptimizeTopK : public OpRewritePattern { // for last dimension. // It can be done by verifying the number of elements: // i.e., num_input/input_last_dim = num_result/k - auto input_ty = value.getType().dyn_cast_or_null(); - auto result_ty = slice_op.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast_or_null(value.getType()); + auto result_ty = mlir::dyn_cast(slice_op.getType()); if (!input_ty || !result_ty) return std::nullopt; if (!input_ty.hasStaticShape() || !result_ty.hasStaticShape()) return std::nullopt; @@ -2230,8 +2139,8 @@ struct OptimizeTopK : public OpRewritePattern { Value k_cst = rewriter.create( op.getLoc(), DenseElementsAttr::get(k_ty, k)); // Compute new result types. 
- auto values_ty = values.getType().dyn_cast(); - auto indices_ty = indices.getType().dyn_cast(); + auto values_ty = mlir::dyn_cast(values.getType()); + auto indices_ty = mlir::dyn_cast(indices.getType()); auto shape = std::vector(); for (auto d : values_ty.getShape().drop_back()) { shape.push_back(d); @@ -2439,7 +2348,7 @@ struct FuseLogSoftmax : public OpRewritePattern { if (!sum_op || !sum_op.getKeepDims() || !isSupportedAxis( sum_op.getAxes(), - sum_op.getOperand(0).getType().cast().getRank())) { + mlir::cast(sum_op.getOperand(0).getType()).getRank())) { return failure(); } if (!sum_op->hasOneUse()) { @@ -2466,10 +2375,10 @@ struct FuseLogSoftmax : public OpRewritePattern { parent_sub_op.getRhs().getDefiningOp()); if (!reduce_max_op || !reduce_max_op->hasOneUse() || !reduce_max_op.getKeepDims() || - !isSupportedAxis(reduce_max_op.getAxes(), reduce_max_op.getOperand(0) - .getType() - .cast() - .getRank())) { + !isSupportedAxis( + reduce_max_op.getAxes(), + mlir::cast(reduce_max_op.getOperand(0).getType()) + .getRank())) { return failure(); } @@ -2562,7 +2471,7 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs } auto bias_type = bias_op.getType(); - auto bias_rank = bias_type.cast().getRank(); + auto bias_rank = mlir::cast(bias_type).getRank(); if (bias_rank > 4 || bias_rank < 2) { return failure(); } @@ -2587,8 +2496,8 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs q_op.setOperand(new_bias_op); auto new_q_op_type = RankedTensorType::Builder( - q_op.getResult().getType().cast()) - .setShape(new_bias_type.cast().getShape()); + mlir::cast(q_op.getResult().getType())) + .setShape(mlir::cast(new_bias_type).getShape()); q_op.getResult().setType(new_q_op_type); auto attr = TypeAttr::get(q_op.getResult().getType()); q_op.setQtypeAttr(attr); @@ -2596,8 +2505,8 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs // Update DequantizeOp's output shape auto new_dq_op_type = RankedTensorType::Builder( - dq_op.getResult().getType().cast()) - .setShape(new_bias_type.cast().getShape()); + mlir::cast(dq_op.getResult().getType())) + .setShape(mlir::cast(new_bias_type).getShape()); dq_op.getResult().setType(new_dq_op_type); // Remove old bias @@ -2655,9 +2564,9 @@ void OptimizePass::runOnOperation() { FuseFullyConnectedAndReluX, FuseBinaryOpToFollowingConv2D, FuseBinaryOpToFollowingDepthwiseConv2D, FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs, - FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp, - RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected, - FuseUnpackAndConcatToReshape, OptimizeTopK, FuseAddAndStridedSlice, + FuseDepthwiseConv2DAndMulWithQDQs, RemoveReshapeAfterFullyConnected, + RemoveReshapeBeforeFullyConnected, FuseUnpackAndConcatToReshape, + OptimizeTopK, FuseAddAndStridedSlice, FuseReshapeAndTransposeAroundBatchMatmul, FuseTransposeReshapeIntoBatchMatmul>(ctx); if (!this->disable_fuse_mul_and_fc_) { diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc index 5b696b52db4b2e..0eacfcb8ef09f0 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc @@ -94,7 +94,8 @@ struct ConvertBatchMatMulOp2FullyConnectedOp // Create a tfl.transpose op that performs ZX transpose on `input`. 
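For orientation on the lambda below: the "ZX transpose" it constructs swaps only the two innermost dimensions and leaves every batch dimension in place, which is what lets a transposed BatchMatMul operand feed a FullyConnected op. A condensed sketch of the permutation it materializes (the helper is ours, assuming rank >= 2):

    #include <numeric>   // std::iota
    #include <utility>   // std::swap
    #include "llvm/ADT/SmallVector.h"

    // Identity permutation with the last two entries swapped,
    // e.g. rank 4 -> [0, 1, 3, 2].
    static llvm::SmallVector<int32_t> MakeZXPermutation(int input_rank) {
      llvm::SmallVector<int32_t> perm(input_rank);
      std::iota(perm.begin(), perm.end(), 0);
      std::swap(perm[input_rank - 1], perm[input_rank - 2]);
      return perm;
    }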
auto create_z_x_transpose_op = [&](Value input) -> Value { - RankedTensorType input_type = input.getType().cast(); + RankedTensorType input_type = + mlir::cast(input.getType()); const int input_rank = input_type.getRank(); // Create a 1D I32 tensor for representing the dimension permutation. @@ -176,7 +177,7 @@ struct ConvertBatchMatMulOpToReduceSum // the adj(X|Y) attribute, respectively. // So adjX == True indicates [..., c_x, r_x == 1]. llvm::ArrayRef lhs_shape = - bmm_op.getX().getType().cast().getShape(); + mlir::cast(bmm_op.getX().getType()).getShape(); int rX = lhs_shape.size() - 2; int cX = lhs_shape.size() - 1; if (bmm_op.getAdjX()) { @@ -189,7 +190,7 @@ struct ConvertBatchMatMulOpToReduceSum } llvm::ArrayRef rhs_shape = - bmm_op.getY().getType().cast().getShape(); + mlir::cast(bmm_op.getY().getType()).getShape(); int rY = rhs_shape.size() - 1; int cY = rhs_shape.size() - 2; if (bmm_op.getAdjX()) { @@ -210,11 +211,11 @@ struct ConvertBatchMatMulOpToReduceSum private: bool SplatValueEquals(SplatElementsAttr float_or_int, double rhs) const { - if (float_or_int.isa()) { - return float_or_int.cast() + if (mlir::isa(float_or_int)) { + return mlir::cast(float_or_int) .getSplatValue() .isExactlyValue(rhs); - } else if (float_or_int.cast()) { + } else if (mlir::cast(float_or_int)) { return float_or_int.getSplatValue() == static_cast(rhs); } return false; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index 7d7ab4b5acd33d..69137210b48ffc 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -21,12 +21,13 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -110,7 +111,7 @@ class FoldIfOp : public OpRewritePattern { if (!matchPattern(op.getCond(), m_Constant(&cond))) return failure(); // TODO(hinsu): Handle constants that are not scalar booleans. - auto cond_type = cond.getType().dyn_cast(); + auto cond_type = mlir::dyn_cast(cond.getType()); if (!cond_type || !cond_type.getShape().equals({}) || !cond_type.getElementType().isInteger(/*width=*/1)) return failure(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc index 4ce0a3b8c43225..62c2c43778e254 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -66,9 +67,9 @@ struct PushDownDequantize : public OpRewritePattern { // Only push down the dequantize op when the output is smaller, so that it // can have smaller memory usage. auto input_type = - dequantize_op.getOutput().getType().dyn_cast(); - auto output_type = - passthrough_op->getResult(0).getType().dyn_cast(); + mlir::dyn_cast(dequantize_op.getOutput().getType()); + auto output_type = mlir::dyn_cast( + passthrough_op->getResult(0).getType()); if (!input_type || !output_type || get_num_elements(input_type) <= get_num_elements(output_type)) { return failure(); @@ -85,7 +86,7 @@ struct PushDownDequantize : public OpRewritePattern { // Set the input type of the passthrough op and pull it up. Type new_output_type; - if (input_element_type.isa()) { + if (mlir::isa(input_element_type)) { new_output_type = QuantizedType::getQuantizedElementType( dequantize_op.getInput().getType()) .castFromExpressedType(output_type); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 0b068972c8fd30..4353b82e2fb901 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -366,6 +366,15 @@ class ConstDoubleValueLessThan : Constraint< "std::abs(*$0.cast().getValues().begin()) < " # n>>; +// Constraint that the attribute value is negative infinity or negative largest. +// We use both -inf & flt_min due to the forward compatibility. +def ConstAPFloatNegLargestOrNegInfinity : Constraint() && " + "$0.cast().getNumElements() == 1 && " + "(($0.cast().getValues()[0].isLargest() && " + "$0.cast().getValues()[0].isNegative()) || " + "$0.cast().getValues()[0].isNegInfinity())">>; + def L2NormValidReduceIndex : Constraint())">>; @@ -771,9 +780,13 @@ def UndoBroadcastConvBiasAdd : Pat< (HasRankAtLeast<2> $bias), (IsDefinedByConv2DOp $lhs)]>; -// Function to map final permutation to initial permutation -// initial -> permutation1 -> permutation2 -> final -def RemapPermutation: NativeCodeCall<"RemapPermutation($0, $1)">; +// Pattern to convert a trivial transpose op to a reshape op. +def ConvertTrivialTransposeOpToReshapeOp : Pat< + (TFL_TransposeOp:$transpose_op $input, (Arith_ConstantOp:$permutation $p1)), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $transpose_op))), + [(IsTransposeTrivial $input, $permutation), + (AnyStaticShapeTensor $input), + (AnyStaticShapeTensor $transpose_op)]>; // Pattern to fuse redundant tanspose op def FoldDoubleTranspose : Pat< @@ -1013,6 +1026,30 @@ def FoldNormalizationIntoSoftmax : Pat< (HasOneUse $sub), (HasOneUse $max)]>; +// Convert softmax(x-reshape(maximum(max(x), -inf))) into softmax(x) as the softmax op already deals +// with the max normalization. 
This comes from upstream Jax (https://github.com/google/jax/pull/15677)
+def FoldNormalizationIntoSoftmaxJaxWithAxisMinus1 : Pat<
+  (TFL_SoftmaxOp
+    (TFL_SubOp:$sub $input,
+      (TFL_ReshapeOp:$reshape
+        (TFL_MaximumOp:$maximum
+          (TFL_ReduceMaxOp:$max $max_input, (Arith_ConstantOp I32ElementsAttr: $axes),
+            ConstBoolAttrFalse),
+          (Arith_ConstantOp F32ElementsAttr: $threshold)
+        ),
+        (Arith_ConstantOp I32ElementsAttr: $shape)
+      ),
+      TFL_AF_None),
+    $beta),
+  (TFL_SoftmaxOp $input, $beta),
+  [(IsSame $input, $max_input),
+   (AxesIsLastDimension $axes, $max_input),
+   (ConstAPFloatNegLargestOrNegInfinity $threshold),
+   (HasOneUse $maximum),
+   (HasOneUse $reshape),
+   (HasOneUse $sub),
+   (HasOneUse $max)]>;
+
 def HaveSameType : Constraint>;

 class AllElementsAreF32 : Constraint;

+// Fuse a redundant RHS TFL_TransposeOp into TFL_BatchMatMulOp if the RHS is a
+// constant tensor of rank 2.
+def FuseTransposeIntoBatchMatMulRHS: Pat<
+  (TFL_BatchMatMulOp $lhs,
+    (TFL_TransposeOp (TFL_QConstOp:$input $_, $_), (Arith_ConstantOp:$perm_value $p0)),
+    $adj_x, $adj_y, $asymmetric_quantize_inputs),
+  (TFL_FullyConnectedOp
+    $lhs,
+    $input, (CreateNoneValue $lhs), TFL_AF_None, TFL_FCWO_Default,
+    ConstBoolAttrTrue, $asymmetric_quantize_inputs),
+  [(HasRank<2> $input),
+   (AreLastTwoDimsTransposed $perm_value),
+   (IsBoolAttrEqual<"false"> $adj_y)]>;
+
 // Replace conv-->transpose-->add with conv-->add-->transpose
 // The bias needs only reshape (i.e. ReshapeNCHWBiasToNHWC) and not transpose
 // because the bias's shape simply changes from NxCx1x1 to Nx1x1xC.
diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td
index eefb109d2b966e..b2ab947b3895b3 100644
--- a/tensorflow/compiler/mlir/lite/transforms/passes.td
+++ b/tensorflow/compiler/mlir/lite/transforms/passes.td
@@ -397,6 +397,9 @@ def QuantizePass : Pass<"tfl-quantize", "mlir::func::FuncOp"> {
            "std::string", "Names of location to blocklist from quantization">,
       Option<"enable_custom_op_weight_only_", "enable-custom-op-weight-only",
            "std::string", "", "Specifies which custom ops are weight-only.">,
+      Option<"enable_float16_quantization_",
+           "enable-float16-quantization", "bool",
+           "false", "Whether to apply float16 quantization. If false, int8 quantization is applied.">,
   ];
 }
diff --git a/tensorflow/compiler/mlir/lite/transforms/pin_ops_with_side_effects.cc b/tensorflow/compiler/mlir/lite/transforms/pin_ops_with_side_effects.cc
index 1d0cd497b052f3..7baa0136f1c33c 100644
--- a/tensorflow/compiler/mlir/lite/transforms/pin_ops_with_side_effects.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/pin_ops_with_side_effects.cc
@@ -37,9 +37,9 @@ namespace {
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

 bool IsResourceTensor(Value value) {
-  const auto tensor_type = value.getType().dyn_cast();
+  const auto tensor_type = mlir::dyn_cast(value.getType());
   return tensor_type &&
-         tensor_type.getElementType().isa();
+         mlir::isa(tensor_type.getElementType());
 }

 // The default criterion for operations being considered as causing or being
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
index 80d7ab24c23316..867eecff15818f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -233,8 +234,8 @@ struct FoldTransposeOp : public OpRewritePattern { DenseIntElementsAttr perm_tensor; if (!matchPattern(op.getPerm(), m_Constant(&perm_tensor))) return failure(); - if (!(getElementTypeOrSelf(op.getOutput().getType())) - .isa()) + if (!mlir::isa( + (getElementTypeOrSelf(op.getOutput().getType())))) return failure(); ElementsAttr input_tensor = qconst_op.getValue(); @@ -244,7 +245,7 @@ struct FoldTransposeOp : public OpRewritePattern { assert(perm_tensor.getType().getNumElements() == num_dimensions); ArrayRef input_shape = input_tensor.getShapedType().getShape(); - auto output_type = op.getOutput().getType().cast(); + auto output_type = mlir::cast(op.getOutput().getType()); SmallVector perm; SmallVector output_shape; @@ -265,9 +266,9 @@ struct FoldTransposeOp : public OpRewritePattern { auto result_type = RankedTensorType::get(output_shape, output_type.getElementType()); auto values_type = RankedTensorType::get( - output_shape, output_type.getElementType() - .cast() - .getStorageType()); + output_shape, + mlir::cast(output_type.getElementType()) + .getStorageType()); rewriter.replaceOpWithNewOp( op, TypeAttr::get(result_type), DenseIntElementsAttr::get(values_type, new_values)); @@ -289,18 +290,18 @@ struct FoldReshapeOp : public OpRewritePattern { if (qconst_op == nullptr) return failure(); auto dense_elements = - qconst_op.getValue().dyn_cast_or_null(); + mlir::dyn_cast_or_null(qconst_op.getValue()); if (dense_elements == nullptr) return failure(); // Handle per tensor cases only. - if (!(getElementTypeOrSelf(op.getType())) - .isa()) { + if (!mlir::isa( + (getElementTypeOrSelf(op.getType())))) { return failure(); } // Remove identity reshape with both static result and input shape. 
- auto result_type = op.getType().cast(); - auto input_type = op.getInput().getType().cast(); + auto result_type = mlir::cast(op.getType()); + auto input_type = mlir::cast(op.getInput().getType()); // Constant folding // If the result type isn't static, tries to derive the result type from @@ -318,9 +319,9 @@ struct FoldReshapeOp : public OpRewritePattern { RankedTensorType::get(shape_data, input_type.getElementType()); } auto values_type = RankedTensorType::get( - result_type.getShape(), result_type.getElementType() - .cast() - .getStorageType()); + result_type.getShape(), + mlir::cast(result_type.getElementType()) + .getStorageType()); DenseElementsAttr reshaped_elements = dense_elements.reshape(values_type); rewriter.replaceOpWithNewOp(op, TypeAttr::get(result_type), diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 98582d03c553b9..9ed32a1b9a674e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -80,12 +80,16 @@ LogicalResult CreateTflFusableOpCustomOptions( size_t start_map = fbb.StartMap(); for (auto attr : attrs) { - if (auto float_attr = attr.second.dyn_cast_or_null()) { + if (auto float_attr = mlir::dyn_cast_or_null(attr.second)) { fbb.Float(attr.first.data(), float_attr.getValue().convertToFloat()); - } else if (auto int_attr = attr.second.dyn_cast_or_null()) { + } else if (auto int_attr = + mlir::dyn_cast_or_null(attr.second)) { fbb.Int(attr.first.data(), int_attr.getInt()); - } else if (auto bool_attr = attr.second.dyn_cast_or_null()) { + } else if (auto bool_attr = mlir::dyn_cast_or_null(attr.second)) { fbb.Bool(attr.first.data(), bool_attr.getValue()); + } else if (auto string_attr = + mlir::dyn_cast_or_null(attr.second)) { + fbb.String(attr.first.data(), string_attr.getValue().str()); } else { // TODO(b/201482289): support other data types. return failure(); @@ -180,7 +184,7 @@ LogicalResult CheckFusableLayerNormalizedLstmCellSimple( func::FuncOp lstm_func) { for (int i = 0; i < 5; ++i) { auto input = lstm_func.getArgument(i); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) { lstm_func.emitWarning( "we cannot fuse this lstm func because all the inputs have not " @@ -195,7 +199,7 @@ LogicalResult CheckFusableLayerNormalizedLstmCellSimple( LogicalResult CheckFusableLstmCellSimple(func::FuncOp lstm_func) { for (int i = 0; i < 4; ++i) { auto input = lstm_func.getArgument(i); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) { lstm_func.emitWarning( "we cannot fuse this lstm func because all the inputs have not " @@ -248,7 +252,7 @@ LogicalResult CheckFusableKerasLstm(func::FuncOp lstm_func, ModuleOp module) { // types. for (int i = 0; i < 6; ++i) { auto input = lstm_func.getArgument(i); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) { lstm_func.emitWarning( "we cannot fuse this lstm func because all the inputs have not " @@ -366,7 +370,7 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes( for (auto attr_item : dict_attr) { // Push other attributes except the TFLFusableOp. 
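The new StringAttr branch in CreateTflFusableOpCustomOptions above writes into the same flexbuffer map as the existing float/int/bool branches. A standalone sketch of that serialization (the key and value are illustrative, not from the patch):

    #include <cstdint>
    #include <vector>
    #include "flatbuffers/flexbuffers.h"

    // Build a custom-options payload holding one string attribute.
    std::vector<uint8_t> BuildCustomOptions() {
      flexbuffers::Builder fbb;
      size_t start_map = fbb.StartMap();
      fbb.String("padding", "SAME");  // hypothetical attribute
      fbb.EndMap(start_map);
      fbb.Finish();
      return fbb.GetBuffer();
    }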
     if (attr_item.getName() == kTFLFusableOp &&
-        attr_item.getValue().dyn_cast().getValue()) {
+        mlir::dyn_cast(attr_item.getValue()).getValue()) {
       tfl_fusable_op = true;
     } else {
       attributes.push_back({attr_item.getName(), attr_item.getValue()});
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
index c625b329be6413..78951ae16397f6 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
@@ -15,6 +15,7 @@ limitations under the License.

 include "tensorflow/compiler/mlir/tensorflow/transforms/optimize.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
+include "tensorflow/compiler/mlir/lite/utils/utils.td"

 def FalseBoolAttr : AttrConstraint>;

@@ -67,6 +68,28 @@ def ConvertMatmulWithTranspose : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt
           /*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))),
     $b, ConstBoolAttrFalse, $bt, $grad_a, $grad_b)>;

+// Pattern to fuse redundant transpose op
+def FoldDoubleTranspose : Pat<
+  (TF_TransposeOp
+    (TF_TransposeOp:$transpose_out1 $input, (Arith_ConstantOp:$permutation1 $p1)),
+    (Arith_ConstantOp:$permutation2 $p2)),
+  (TF_TransposeOp $input,
+    (Arith_ConstantOp (RemapPermutation $permutation1, $permutation2))),
+  [(HasOneUse $transpose_out1)]>;
+
+// Pattern to fuse trivial reshape op into transpose op
+def FoldTrivialReshapeIntoTranspose : Pat<
+  (TF_ReshapeOp:$output
+    (TF_TransposeOp:$transpose_out1 $input, (Arith_ConstantOp:$permutation1 $p1)), $_),
+  (TF_TransposeOp:$transpose_op $input,
+    (Arith_ConstantOp
+      (RemapPermutation $permutation1,
+        (GetPermutationFromTrivialReshape $transpose_out1, $output)))),
+  [(IsReshapeEquivalentToTranspose $transpose_out1, $output),
+   (AnyStaticShapeTensor $input),
+   (AnyStaticShapeTensor $output),
+   (HasOneUse $transpose_out1)]>;
+
 // Partially supported in TFLite, treated as passthrough IdentityOp
 def ConvertCheckNumerics : Pat<(TF_CheckNumericsOp $arg, $msg), (TF_IdentityOp $arg)>;
 def ConvertSnapshot : Pat<(TF_SnapshotOp $arg), (TF_IdentityOp $arg)>;
@@ -136,6 +159,19 @@ def ReorderReshapeDequantQuantUsedByDepthwiseConv :
       (CanUpdateShapeWithAxis<3> $qtype, $old_value)],
   [], (addBenefit 10)>;

+// The axis is set to 3, because this transpose is from the legalization of
+// tf.DepthwiseConv2dNative and the new channel axis is the last dimension.
+def ReorderTransposeDequantQuantUsedByDepthwiseConv :
+  Pat<(TF_TransposeOp:$old_value
+        (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype)), $perm),
+      (TFL_DequantizeOp
+        (TFL_QuantizeOp
+          (TF_TransposeOp $input, $perm),
+          (UpdateShapeWithAxis<3> $qtype, $old_value))),
+      [(UsedBy<"DepthwiseConv2D"> $old_value),
+       (CanUpdateShapeWithAxis<3> $qtype, $old_value)],
+  [], (addBenefit 10)>;
+
 // The Rank op produces a result that is independent of the quantization
 // parameters of the input, so we can remove the quantization ops.
def OptimizeAwayRankDequantQuant : diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index ce11ca73970136..9f76ad1f6e9098 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -153,8 +153,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { bool need_to_set_input_nodes_quantization_params = false; for (const BlockArgument arg : func.getArguments()) { - auto shaped = arg.getType().dyn_cast(); - if (shaped && shaped.getElementType().isa() && + auto shaped = mlir::dyn_cast(arg.getType()); + if (shaped && mlir::isa(shaped.getElementType()) && !has_quantize_op(arg)) { need_to_set_input_nodes_quantization_params = true; break; @@ -179,8 +179,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { auto add_quantize_op = [&](Location loc, Type input_type, Block* block, Block::iterator insertion_point, Value arg, int i) { - if (auto shaped = input_type.dyn_cast()) { - if (shaped.getElementType().isa()) { + if (auto shaped = mlir::dyn_cast(input_type)) { + if (mlir::isa(shaped.getElementType())) { // If there are existing quantize ops, they are from training and we // should respect them. if (has_quantize_op(arg)) { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index f0fd79ff207f39..0b823844aa4a58 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" @@ -193,7 +194,7 @@ class PrepareDynamicRangeQuantizableOp continue; } - if (attr.dyn_cast().size() >= + if (mlir::dyn_cast(attr).size() >= quant_specs_.minimum_elements_for_weights) { continue; } @@ -205,7 +206,7 @@ class PrepareDynamicRangeQuantizableOp "supported. 
The operand ") << const_op->getName().getStringRef().str() << " at index " << qi << " was not quantized because it has " - << attr.dyn_cast().size() + << mlir::dyn_cast(attr).size() << " elements which is fewer than the " "`minimum_elements_for_weights` threshold of " << quant_specs_.minimum_elements_for_weights; @@ -233,7 +234,7 @@ class PrepareDynamicRangeQuantizableOp // Get types TensorType old_result_type = - op.getResult().getType().template dyn_cast(); + mlir::dyn_cast(op.getResult().getType()); FloatType quantized_type = FloatType::getF16(op.getContext()); ShapedType new_result_type = old_result_type.clone(quantized_type); @@ -287,27 +288,27 @@ class PrepareDynamicRangeQuantizableOp DenseFPElementsAttr attr; if (!matchPattern(op->getResult(0), m_Constant(&attr))) return false; - if (attr.dyn_cast().size() < + if (mlir::dyn_cast(attr).size() < quant_specs_.minimum_elements_for_weights) { op->emitRemark("Quantization is skipped for ") << quantize_op->getName().getStringRef().str() << " because it has " - << attr.dyn_cast().size() + << mlir::dyn_cast(attr).size() << " elements which is fewer than the threshold(" << quant_specs_.minimum_elements_for_weights << " elements)."; return false; } if (op_with_per_axis_support) { - quant_type = quant::GetUniformQuantizedPerAxisTypeForWeight( - attr, affine_user.GetQuantizationDimIndex(), - /*symmetric=*/true, bit_width, is_signed, - is_narrow_range, is_legacy_float) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedPerAxisTypeForWeight( + attr, affine_user.GetQuantizationDimIndex(), + /*symmetric=*/true, bit_width, is_signed, is_narrow_range, + is_legacy_float)); } else { - quant_type = quant::GetUniformQuantizedTypeForWeight( - attr, is_narrow_range && is_signed, bit_width, is_signed, - is_narrow_range, is_legacy_float) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedTypeForWeight( + attr, is_narrow_range && is_signed, bit_width, is_signed, + is_narrow_range, is_legacy_float)); } return insertQDQ(rewriter, op, quant_type, quant_op); } @@ -346,7 +347,7 @@ class PrepareDynamicRangeQuantizableOp bool getQuantizableOps(arith::ConstantOp op, QuantizationUnits& quantizable_ops) const { // Non-float tensors do not need quantization. - auto type = op.getType().dyn_cast(); + auto type = mlir::dyn_cast(op.getType()); if (!type || !type.getElementType().isF32()) return false; Value value = op.getResult(); @@ -420,7 +421,7 @@ class PrepareDynamicRangeQuantizableOp // Get types Type old_result_type = op.getResult().getType(); ShapedType new_result_type = - cast_op.getType().template dyn_cast(); + mlir::dyn_cast(cast_op.getType()); // Proceeds only if the casting is to float16 if (!new_result_type.getElementType().isF16()) continue; @@ -428,7 +429,7 @@ class PrepareDynamicRangeQuantizableOp // Cast values std::vector new_values; DenseFPElementsAttr value_attr = - op.getValue().cast(); + mlir::cast(op.getValue()); new_values.reserve(value_attr.getNumElements()); constexpr float kMaxFloat16Value = 65504.f; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h index e102c6bedd4328..061a8db4398321 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -36,16 +36,17 @@ limitations under the License. 
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/tools/optimize/operator_property.h" //===----------------------------------------------------------------------===// @@ -100,19 +101,18 @@ LogicalResult GetLstmProperty(LstmOp op, return failure(); } lstm_variant->use_projection = - !op.getProjectionWeights().getType().template isa(); + !mlir::isa(op.getProjectionWeights().getType()); lstm_variant->use_peephole = - !op.getCellToOutputWeights().getType().template isa(); + !mlir::isa(op.getCellToOutputWeights().getType()); lstm_variant->use_layer_norm = - !op.getForgetLayerNormCoefficients().getType().template isa(); + !mlir::isa(op.getForgetLayerNormCoefficients().getType()); *op_property = operator_property::GetOperatorProperty( *lstm_variant, activation_number_of_bits); // TODO(b/176258587) move this to operator_property.cc if this is needed in // other components, too. - bool use_cifg = - op.getInputToInputWeights().getType().template isa(); + bool use_cifg = mlir::isa(op.getInputToInputWeights().getType()); if (use_cifg) { const absl::flat_hash_set cifg_non_inputs = {1, 5, 9, 12, 20}; const int cifg_non_intermediate = 0; @@ -197,9 +197,9 @@ class PrepareLstmOutputScale : public OpRewritePattern { llvm::SmallVector min_max_values; for (auto& stats_op : stats_ops) { - auto values = stats_op.getLayerStats() - .dyn_cast() - .getValues(); + auto values = + mlir::dyn_cast(stats_op.getLayerStats()) + .getValues(); min_max_values.insert(min_max_values.end(), values.begin(), values.end()); } @@ -285,8 +285,8 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { const operator_property::TensorProperty& tensor_property, PatternRewriter& rewriter) const { // Non-float tensors are neither weights nor require quantization. 
- auto type = const_op->getResult(0).getType().dyn_cast(); - if (!type || !type.getElementType().isa()) return success(); + auto type = mlir::dyn_cast(const_op->getResult(0).getType()); + if (!type || !mlir::isa(type.getElementType())) return success(); DenseFPElementsAttr attr; if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) { @@ -312,12 +312,12 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { rewriter.getIntegerType(16), attr.getType().getElementType(), scale, /*zeroPoint=*/0, llvm::minIntN(10), -llvm::minIntN(10)); } else { - quant_type = quant::GetUniformQuantizedTypeForWeight( - attr, /*symmetric=*/true, - /*num_bits=*/tensor_property.number_of_bits, - /*is_signed=*/true, - /*narrow_range=*/true, quant_specs_.legacy_float_scale) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/true, + /*num_bits=*/tensor_property.number_of_bits, + /*is_signed=*/true, + /*narrow_range=*/true, quant_specs_.legacy_float_scale)); } if (!quant_type) { const_op->emitError("Failed to get quantized type"); @@ -346,7 +346,7 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { << "] is a state tensor, but has more than one use."; return failure(); } - auto stats = stats_op.getLayerStats().dyn_cast(); + auto stats = mlir::dyn_cast(stats_op.getLayerStats()); if (!stats || stats.getNumElements() != 2) { stats_op.emitError("Stats should have 2 values."); return failure(); @@ -454,7 +454,7 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { return failure(); } auto calibrated_type = - quant_type.template dyn_cast(); + mlir::dyn_cast(quant_type); if (!calibrated_type) { int num_storage_bits = quant_type.getStorageTypeIntegralWidth(); if (tensor_property.number_of_bits != num_storage_bits) { @@ -474,9 +474,9 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { /*narrowRange=*/false, calibrated_type.getExpressedType(), /*isSigned=*/this->quant_specs_.IsSignedInferenceType()); if (this->quant_specs_.legacy_float_scale) { - qtype = quant::DownCastScale(qtype, calibrated_type.getMin(), - calibrated_type.getMax(), op.getLoc()) - .template cast(); + qtype = mlir::cast( + quant::DownCastScale(qtype, calibrated_type.getMin(), + calibrated_type.getMax(), op.getLoc())); } } else if (tensor_property.number_of_bits == 16) { double max = std::max(std::abs(calibrated_type.getMin()), @@ -508,9 +508,9 @@ inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale( return [=](const std::vector& quant_params, const int adjusted_quant_dim, const bool legacy_float_scale) -> quant::QuantParams { - if (auto qtype = quant::GetUniformQuantizedTypeForBias( - quant_params, legacy_float_scale, adjusted_quant_dim) - .dyn_cast_or_null()) { + if (auto qtype = mlir::dyn_cast_or_null( + quant::GetUniformQuantizedTypeForBias( + quant_params, legacy_float_scale, adjusted_quant_dim))) { return quant::UniformQuantizedType::get( qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), qtype.getScale() * scale, qtype.getZeroPoint(), @@ -540,14 +540,14 @@ std::unique_ptr GetLstmOpQuantSpec(LstmOp op) { tensor_property.derived_scale.intermediate_tensors) { auto quant_type = GetIntermediateElementType(op, tensor_index); if (!quant_type || - !quant_type.template isa()) { + !mlir::isa(quant_type)) { op->emitError() << "While processing derived scale, intermediate " << intermediate_attributes[tensor_index] << " is not quantized."; return nullptr; } - scale *= quant_type.template dyn_cast() - .getScale(); + scale *= 
+ mlir::dyn_cast(quant_type).getScale(); } for (float factor : tensor_property.derived_scale.factors) { scale *= factor; @@ -590,7 +590,8 @@ class PropagateTransposedPerAxisQuantDim auto q_op = dyn_cast_or_null( dq_op.getOperand().getDefiningOp()); if (!q_op) return failure(); - auto qtype = dq_op.getArg().getType().cast().getElementType(); + auto qtype = + mlir::cast(dq_op.getArg().getType()).getElementType(); auto aqtype = dyn_cast_or_null(qtype); if (!aqtype) return failure(); @@ -599,8 +600,8 @@ class PropagateTransposedPerAxisQuantDim auto next_op = *transpose_op.getResult().getUsers().begin(); if (dyn_cast_or_null(next_op)) return failure(); - auto input_type = transpose_op.getInput().getType().cast(); - auto perm_type = transpose_op.getPerm().getType().cast(); + auto input_type = mlir::cast(transpose_op.getInput().getType()); + auto perm_type = mlir::cast(transpose_op.getPerm().getType()); if (input_type.hasStaticShape() && perm_type.hasStaticShape()) { if (perm_type.getNumElements() != input_type.getRank()) { return transpose_op.emitOpError( diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 9f0a7fbafff450..b0b6fc8ac7f2d8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -67,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/size_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" @@ -89,7 +90,7 @@ namespace { // Preconditions: The given value must have a ShapedType. static Value CreateTFCastOpI32(OpBuilder *builder, Location loc, Value x, BoolAttr truncate) { - auto x_type = x.getType().dyn_cast_or_null(); + auto x_type = mlir::dyn_cast_or_null(x.getType()); if (!x_type) llvm_unreachable("unsupported type"); Type type = x_type.clone(builder->getI32Type()); return builder->create(loc, type, x, truncate); @@ -200,14 +201,14 @@ class ConvertTFConvOp : public RewritePattern { // that we can extract info from the shape (e.g., for constructing bias // tensor, for setting depth_multiplier attribute, etc.). auto filter = tf_op.getFilter(); - auto filter_type = filter.getType().template dyn_cast(); + auto filter_type = mlir::dyn_cast(filter.getType()); if (!filter_type || filter_type.getRank() != 4 || !filter_type.hasStaticShape()) return failure(); Value input = tf_op.getInput(); RankedTensorType input_type = - input.getType().template dyn_cast(); + mlir::dyn_cast(input.getType()); // Only rank size four input will be only available by the tf.Conv2D // operator verification. if (!input_type || input_type.isDynamicDim(3)) { @@ -244,7 +245,7 @@ class ConvertTFConvOp : public RewritePattern { op->getAttrOfType("explicit_paddings").getValue(); auto get_int = [](Attribute attr) { - return attr.template cast().getInt(); + return mlir::cast(attr).getInt(); }; SmallVector padding_values(padding_attr_array.size()); @@ -324,7 +325,7 @@ class ConvertTFConv2D : public ConvertTFConvOp { auto perm_op = rewriter.create(loc, perm_type, perm_attr); // Create tensor type for the transpose result. 
- auto filter_type = filter.getType().cast(); + auto filter_type = mlir::cast(filter.getType()); auto result_shape = llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) { return filter_type.getDimSize(dim); @@ -361,7 +362,8 @@ class ConvertTFDepthwiseConv2dNative // have a corresponding 'depth_multiplier' attribute; the multiplier is the // fourth dimension in the 4-D filter tensor. We query the multiplier from // tf.DepthwiseConv2dNative and set it as the attribute value accordingly. - auto multiplier = filter.getType().cast().getDimSize(3); + auto multiplier = + mlir::cast(filter.getType()).getDimSize(3); filter = legalizeFilter(rewriter, loc, filter); return rewriter.create( @@ -385,7 +387,7 @@ class ConvertTFDepthwiseConv2dNative /// RankedTensorType. Value legalizeFilter(PatternRewriter &rewriter, Location loc, Value filter) const { - auto filter_type = filter.getType().cast(); + auto filter_type = mlir::cast(filter.getType()); auto filterShape = filter_type.getShape(); SmallVector result_shape = {1, filterShape[0], filterShape[1], filterShape[2] * filterShape[3]}; @@ -443,7 +445,7 @@ struct ConvertTFStridedSlice : public RewritePattern { // Insert a new reshape op. Value original_input = strided_slice_op.getInput(); RankedTensorType original_input_type = - original_input.getType().dyn_cast(); + mlir::dyn_cast(original_input.getType()); if (!original_input_type) { return failure(); } @@ -522,7 +524,8 @@ struct ConvertTFStridedSlice : public RewritePattern { DenseIntElementsAttr begin_dense_elem_attr; Value begin = strided_slice_op.getBegin(); - auto begin_ranked_attr_type = begin.getType().dyn_cast(); + auto begin_ranked_attr_type = + mlir::dyn_cast(begin.getType()); if (!begin_ranked_attr_type || !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) { return failure(); @@ -530,7 +533,7 @@ struct ConvertTFStridedSlice : public RewritePattern { DenseIntElementsAttr end_dense_elem_attr; Value end = strided_slice_op.getEnd(); - auto end_ranked_attr_type = end.getType().dyn_cast(); + auto end_ranked_attr_type = mlir::dyn_cast(end.getType()); if (!end_ranked_attr_type || !matchPattern(end, m_Constant(&end_dense_elem_attr))) { return failure(); @@ -539,14 +542,15 @@ struct ConvertTFStridedSlice : public RewritePattern { DenseIntElementsAttr stride_dense_elem_attr; Value stride = strided_slice_op.getStrides(); auto stride_ranked_attr_type = - stride.getType().dyn_cast(); + mlir::dyn_cast(stride.getType()); if (!stride_ranked_attr_type || !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) { return failure(); } Value input = strided_slice_op.getInput(); - RankedTensorType input_type = input.getType().dyn_cast(); + RankedTensorType input_type = + mlir::dyn_cast(input.getType()); if (!input_type) { return failure(); } @@ -554,7 +558,7 @@ struct ConvertTFStridedSlice : public RewritePattern { const int input_size = input_shape.size(); - RankedTensorType begin_type = begin.getType().cast(); + RankedTensorType begin_type = mlir::cast(begin.getType()); const ArrayRef begin_shape = begin_type.getShape(); const int begin_dim = begin_shape.size(); @@ -688,7 +692,7 @@ struct ConvertTFStridedSlice : public RewritePattern { } auto ranked_input_type = - strided_slice_op.getInput().getType().dyn_cast(); + mlir::dyn_cast(strided_slice_op.getInput().getType()); if (!ranked_input_type) { return failure(); } @@ -697,10 +701,11 @@ struct ConvertTFStridedSlice : public RewritePattern { auto end_attr = strided_slice_op.getEnd(); auto strides_attr = strided_slice_op.getStrides(); - auto 
begin_attr_type = begin_attr.getType().dyn_cast(); - auto end_attr_type = end_attr.getType().dyn_cast(); + auto begin_attr_type = + mlir::dyn_cast(begin_attr.getType()); + auto end_attr_type = mlir::dyn_cast(end_attr.getType()); auto strides_attr_type = - strides_attr.getType().dyn_cast(); + mlir::dyn_cast(strides_attr.getType()); DenseIntElementsAttr begin_elem_attr; DenseIntElementsAttr end_elem_attr; @@ -899,8 +904,8 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { if (!epsilon) epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f); - if (!(((epsilon.isa<::mlir::FloatAttr>())) && - ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) { + if (!(((mlir::isa<::mlir::FloatAttr>(epsilon))) && + ((mlir::cast<::mlir::FloatAttr>(epsilon).getType().isF32())))) { return rewriter.notifyMatchFailure( fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to " @@ -963,7 +968,7 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { int64_t last_dim = ShapedType::kDynamic; { auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) { - auto v_type = v.getType().dyn_cast_or_null(); + auto v_type = mlir::dyn_cast_or_null(v.getType()); if (!v_type) return true; int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1); if (v_last_dim == ShapedType::kDynamic) return true; @@ -1007,9 +1012,8 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { // For training, mean and variance is calculated from input values. if (is_training.getValue()) { - auto input_type = fused_batch_norm_op.getX() - .getType() - .dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null( + fused_batch_norm_op.getX().getType()); if (!input_type || input_type.getRank() != 4) { return rewriter.notifyMatchFailure( fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { @@ -1383,14 +1387,14 @@ struct ConvertRfftToRfft2d : public RewritePattern { auto rfft_op = dyn_cast(op); auto input = rfft_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) return failure(); auto fft_len = rfft_op.getFftLength(); - auto fft_len_type = fft_len.getType().dyn_cast_or_null(); + auto fft_len_type = mlir::dyn_cast_or_null(fft_len.getType()); if (!fft_len_type) return failure(); auto output_type = - rfft_op.getResult().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(rfft_op.getResult().getType()); if (!output_type) return failure(); // Expanded inputs. diff --git a/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc b/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc index 363c30ab0b818c..7a8b35e4be7cde 100644 --- a/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc +++ b/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -280,7 +281,7 @@ class CommuteTransposeWithEwiseOps : public RewritePattern { } auto other_input_type = - cst_arg->getResult(0).getType().cast(); + mlir::cast(cst_arg->getResult(0).getType()); Operation *tposed_const; if (other_input_type.getNumElements() == 1) { diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index 9c38d09ab0c2bd..e41f98af795347 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -249,6 +249,9 @@ void QuantizePass::runOnOperation() { quant::CustomOpUpdateOptions::kWeightOnly, quant_specs.custom_map); } + if (enable_float16_quantization_) { + quant_specs.inference_type = tensorflow::DT_HALF; + } const quant::QuantPassSpec quant_params = { {quant_specs.verify_numeric, error_tolerance_, diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc index 0d9db051ef27ff..96412f20633f6a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -169,7 +170,7 @@ void QuantizeVariablesPass::QuantizeVariable( for (VarHandleOp var_handle_op : var_handle_ops) { builder.setInsertionPoint(var_handle_op); auto output_type = UnrankedTensorType::get(TF::ResourceType::get( - {ref_qtype.cast()}, builder.getContext())); + {mlir::cast(ref_qtype)}, builder.getContext())); auto new_var_handle_op = builder.create( var_handle_op.getLoc(), output_type, var_handle_op.getContainer(), var_handle_op.getSharedName()); diff --git a/tensorflow/compiler/mlir/lite/transforms/reduce_type_precision.cc b/tensorflow/compiler/mlir/lite/transforms/reduce_type_precision.cc index bee14272020446..659c5aceb39c04 100644 --- a/tensorflow/compiler/mlir/lite/transforms/reduce_type_precision.cc +++ b/tensorflow/compiler/mlir/lite/transforms/reduce_type_precision.cc @@ -62,12 +62,12 @@ class CheckRangeAndConvertI8ToI4 : public OpRewritePattern { LogicalResult matchAndRewrite(arith::ConstantOp op, PatternRewriter &rewriter) const override { - auto const_type = op.getType().dyn_cast(); + auto const_type = mlir::dyn_cast(op.getType()); if (!const_type || !const_type.getElementType().isSignlessInteger(8)) { return failure(); } - auto attr = op.getValue().cast(); + auto attr = mlir::cast(op.getValue()); for (mlir::APInt v : attr.getValues()) { auto v_int = static_cast(*(v.getRawData())); if (v_int > 7 || v_int < -8) { @@ -79,7 +79,7 @@ class CheckRangeAndConvertI8ToI4 : public OpRewritePattern { auto shaped_type = mlir::RankedTensorType::get(const_type.getShape(), 
                                                 builder.getI4Type());
     auto newAttr = DenseElementsAttr::getFromRawBuffer(
-        shaped_type, op.getValue().cast().getRawData());
+        shaped_type, mlir::cast(op.getValue()).getRawData());
     rewriter.replaceOpWithNewOp(op, newAttr);

     return success();
@@ -92,8 +92,8 @@ class SanitizeGatherOpOutputToI4 : public OpRewritePattern {
   LogicalResult matchAndRewrite(TFL::GatherOp op,
                                 PatternRewriter &rewriter) const override {
-    auto const_type = op.getOperand(0).getType().dyn_cast();
-    auto result_type = op.getResult().getType().dyn_cast();
+    auto const_type = mlir::dyn_cast(op.getOperand(0).getType());
+    auto result_type = mlir::dyn_cast(op.getResult().getType());
     if (!const_type || !const_type.getElementType().isSignlessInteger(4) ||
         !result_type || !result_type.getElementType().isSignlessInteger(8)) {
       return failure();
@@ -109,7 +109,8 @@ class SanitizeGatherOpOutputToI4 : public OpRewritePattern {
     auto new_gather_op = rewriter.create(
         op.getLoc(),
         /*result=*/
-        op.getResult().getType().cast().clone(builder.getI4Type()),
+        mlir::cast(op.getResult().getType())
+            .clone(builder.getI4Type()),
         /*operand=*/op.getOperands(), op->getAttrs());

     rewriter.replaceAllUsesWith(op.getResult(), new_gather_op.getResult());
diff --git a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc
index c8999216c8054b..ab03af3a4c062a 100644
--- a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc
@@ -104,7 +104,7 @@ void FindProducers(Value start_node, std::vector &neighbors) {
   while (!queue.empty()) {
     auto node = queue.back();
     queue.pop_back();
-    if (auto arg = node.dyn_cast_or_null()) {
+    if (auto arg = mlir::dyn_cast_or_null(node)) {
       neighbors.push_back(arg.getArgNumber());
       continue;
     }
@@ -149,7 +149,7 @@ bool AllOperationSafe(Block &block) {
     // Fact: if every op's operands are defined in the same block as op,
     // then no operation has implicit arguments (constant doesn't count).
     for (auto operand : op->getOperands()) {
-      if (operand.dyn_cast_or_null()) continue;
+      if (mlir::dyn_cast_or_null(operand)) continue;
       auto operand_op = operand.getDefiningOp();
       if (IsConstant(operand_op)) continue;
       if (operand_op->getBlock() != op->getBlock()) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
index a787da584ea8be..4c555a8d0f6e3b 100644
--- a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
@@ -85,23 +85,36 @@ LogicalResult DuplicateValueIfNeeded(Operation* op,
     Value operand = op->getOperand(index);
     auto inserted_value = values->insert(operand).second;
     if (inserted_value) continue;
-    // We can only clone the constant op at this point.
-    // Since all ops have been legalized to tflite ops, so we only care about
-    // ConstOp or QConstOp or mlir constant op/
+    // We can only clone the constant op or const->dequantize combo. The latter
+    // case is useful for float16 quantization. Since all ops have been
+    // legalized to tflite ops, we only care about ConstOp, QConstOp, or the
+    // mlir constant op.
     Operation* input_op = operand.getDefiningOp();
     if (input_op == nullptr) return failure();

     Attribute attr;
-    if (!matchPattern(input_op, m_Constant(&attr))) {
+    if (matchPattern(input_op, m_Constant(&attr))) {
+      // Constant case.
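    // (Schematic IR, our illustration rather than part of the patch: the new
    // const->dequantize branch below turns
    //   %w  = "tfl.pseudo_const"() : () -> tensor<...xf16>
    //   %dq = "tfl.dequantize"(%w) : (tensor<...xf16>) -> tensor<...xf32>
    // feeding two stateful operands into two independent %w/%dq chains, so
    // each stateful operand gets its own buffer under float16 quantization.)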
+ builder->setInsertionPoint(op); + Operation* duplicated_input_op = builder->clone(*input_op); + + // Rewire the inputs. + op->setOperand(index, duplicated_input_op->getResult(0)); + } else if (auto dq = dyn_cast(input_op); + dq && matchPattern(dq.getInput(), m_Constant(&attr))) { + // Constant -> Dequantize case. + builder->setInsertionPoint(op); + Operation* duplicated_input_op = + builder->clone(*dq.getInput().getDefiningOp()); + Operation* duplicated_dq_op = builder->clone(*dq); + // Rewire the inputs. + duplicated_dq_op->setOperand(0, duplicated_input_op->getResult(0)); + op->setOperand(index, duplicated_dq_op->getResult(0)); + } else { op->emitError() << "We cannot duplicate the value since it's not constant.\n"; return failure(); } - builder->setInsertionPoint(op); - Operation* duplicated_input_op = builder->clone(*input_op); - - // Rewire the inputs. - op->setOperand(index, duplicated_input_op->getResult(0)); } return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/unfold_large_splat_constant.cc b/tensorflow/compiler/mlir/lite/transforms/unfold_large_splat_constant.cc index 1def97523cd668..2669159b0206bb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/unfold_large_splat_constant.cc +++ b/tensorflow/compiler/mlir/lite/transforms/unfold_large_splat_constant.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -62,7 +63,7 @@ class UnfoldLargeSplatConstantPass void MaybeUnfoldLargeSplatConstant(mlir::OpBuilder* op_builder, mlir::arith::ConstantOp const_op) const { auto splat_elements_attr = - const_op.getValue().dyn_cast(); + mlir::dyn_cast(const_op.getValue()); if (!splat_elements_attr) { return; } diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index a3c3ece3dc94a1..013abb6ec0ea80 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -92,13 +93,13 @@ bool IsCompatibleTypeWithTFLCastOp(Type type) { return true; // Complex> is allowed. - if (elemType.isa() && - elemType.cast().getElementType().isF32()) + if (mlir::isa(elemType) && + mlir::cast(elemType).getElementType().isF32()) return true; // QUINT8 and UI8 are allowed. 
- if (elemType.isa() || - (elemType.isInteger(8) && elemType.cast().isUnsigned())) + if (mlir::isa(elemType) || + (elemType.isInteger(8) && mlir::cast(elemType).isUnsigned())) return true; return false; diff --git a/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h b/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h index ac170f33d9ba85..c851d73b03290d 100644 --- a/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h +++ b/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TFL { @@ -27,7 +28,7 @@ class ArithmeticCountUtilHelper { static bool GetFirstOutputCount(mlir::Operation* op, int64_t* count) { auto output = op->getResult(0); auto output_type = - output.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(output.getType()); if (!output_type || !output_type.hasStaticShape()) return false; *count = output_type.getNumElements(); @@ -38,7 +39,7 @@ class ArithmeticCountUtilHelper { int64_t total_count = 0; for (auto input : op->getOperands()) { auto input_type = - input.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasStaticShape()) { return false; } @@ -54,12 +55,12 @@ class ArithmeticCountUtilHelper { int64_t* count) { auto weight = op->getOperand(1); auto weight_type = - weight.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(weight.getType()); if (weight_type == nullptr || !weight_type.hasStaticShape()) return false; auto output = op->getResult(0); auto output_type = - output.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(output.getType()); if (output_type == nullptr || !output_type.hasStaticShape()) return false; int64_t cols = 1; @@ -73,7 +74,7 @@ class ArithmeticCountUtilHelper { auto bias = op->getOperand(2); if (bias) { auto bias_type = - bias.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(bias.getType()); if (bias_type && bias_type.hasStaticShape()) { *count += output_type.getNumElements(); } diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc index 20336080cc20d6..1629000ff181df 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc @@ -15,23 +15,24 @@ limitations under the License. 
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TFL { FloatAttr ExtractSingleElementAsFloat(ElementsAttr attr) { if (attr.getShapedType().getNumElements() != 1 || - !attr.getShapedType().getElementType().isa()) { + !mlir::isa(attr.getShapedType().getElementType())) { return {}; } return attr.getSplatValue(); } FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) { - if (auto m = attr.dyn_cast_or_null()) { + if (auto m = mlir::dyn_cast_or_null(attr)) { return ExtractSingleElementAsFloat(m); } else { - return attr.dyn_cast_or_null(); + return mlir::dyn_cast_or_null(attr); } } diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc index 96d75cca30a48d..41eed865496a01 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc @@ -36,11 +36,11 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/low_bit_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_util.h" #include "tsl/platform/statusor.h" @@ -131,11 +131,11 @@ StatusOr GetQuantizedType(const TensorT& tensor, Builder builder, if (!storage_type) { const mlir::Type raw_elem_type = ConvertElementType(tensor.type, builder); - if (!raw_elem_type.isa()) { + if (!mlir::isa(raw_elem_type)) { return absl::InvalidArgumentError( "Quantized tensors must be stored as integers"); } - storage_type = raw_elem_type.cast(); + storage_type = mlir::cast(raw_elem_type); } // TFlite uses narrow-range [u]int8 for constant buffers of quantized weights. 
@@ -254,11 +254,11 @@ mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index, return DenseElementsAttr::get( type, builder.getIntegerAttr(element_ty, unique_index)); - if (element_ty.isa()) + if (mlir::isa(element_ty)) return DenseElementsAttr::get( type, builder.getFloatAttr(element_ty, unique_index)); - if (auto qtype = element_ty.dyn_cast()) { + if (auto qtype = mlir::dyn_cast(element_ty)) { mlir::RankedTensorType new_type = tensorflow::GetTypeFromTFTensorShape( type.getShape(), qtype.getStorageType()); return DenseElementsAttr::get( @@ -272,9 +272,10 @@ StatusOr ConvertIntBuffer( bool truncate) { mlir::Type elem_type = shaped_type.getElementType(); unsigned bit_width; - if (auto itype = elem_type.dyn_cast()) { + if (auto itype = mlir::dyn_cast(elem_type)) { bit_width = itype.getWidth(); - } else if (auto qtype = elem_type.dyn_cast()) { + } else if (auto qtype = + mlir::dyn_cast(elem_type)) { bit_width = qtype.getStorageTypeIntegralWidth(); shaped_type = tensorflow::GetTypeFromTFTensorShape(shaped_type.getShape(), qtype.getStorageType()); diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h index 52bcbbef72aba6..d9618517a5dc96 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h @@ -25,9 +25,9 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index 8bf3b4f0106604..6a4dbf3e505ba6 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -56,7 +56,8 @@ absl::StatusOr CreateTypedAttr(ShapedType shaped_type, int value) { } else if (element_type.isF32()) { return DenseElementsAttr::get(shaped_type, static_cast(value)); - } else if (auto complex_type = element_type.dyn_cast()) { + } else if (auto complex_type = + mlir::dyn_cast(element_type)) { auto etype = complex_type.getElementType(); if (etype.isF32()) { tensorflow::TensorProto repr; @@ -77,7 +78,7 @@ absl::StatusOr CreateTypedAttr(ShapedType shaped_type, int value) { return tensorflow::Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); } - } else if (auto itype = element_type.dyn_cast()) { + } else if (auto itype = mlir::dyn_cast(element_type)) { if (element_type.isSignedInteger()) { switch (itype.getWidth()) { case 8: diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index f2e659b9aea9ce..c7f922de39ad81 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -18,12 +18,13 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "xla/statusor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -40,16 +41,16 @@ tflite::TensorType ConvertTypeToTensorType(mlir::Type type) { return tflite::TensorType_FLOAT32; } else if (type.isF64()) { return tflite::TensorType_FLOAT64; - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_STRING; - } else if (auto complex_type = type.dyn_cast()) { + } else if (auto complex_type = mlir::dyn_cast(type)) { if (complex_type.getElementType().isF32()) { return tflite::TensorType_COMPLEX64; } else if (complex_type.getElementType().isF64()) { return tflite::TensorType_COMPLEX128; } llvm_unreachable("invalid complex Type in conversion"); - } else if (auto itype = type.dyn_cast()) { + } else if (auto itype = mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: return tflite::TensorType_BOOL; @@ -209,7 +210,7 @@ absl::StatusOr TfTypeToTflType(tensorflow::DataType type) { mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr) { auto type = type_attr.getValue(); - auto shaped_type = type.dyn_cast(); + auto shaped_type = mlir::dyn_cast(type); if (shaped_type) { return shaped_type.getElementType(); } else { diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.h b/tensorflow/compiler/mlir/lite/utils/convert_type.h index ce26591d52b34a..85631dbe258f8e 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.h +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.h @@ -19,9 +19,9 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "xla/statusor.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { class Builder; diff --git a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h index 77b047f68c6bf2..d1dcf8c304b0a9 100644 --- a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h @@ -123,7 +123,7 @@ class InsertTFLQuantOpsAfterTFFakeQuantOp { int quant_dim = -1; if (PerAxis) { // This is a special case that the quant_dim is the last dimensions. - quant_dim = res.getType().template cast().getRank() - 1; + quant_dim = mlir::cast(res.getType()).getRank() - 1; } // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. 
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 0a563238635d20..bada49a68a9e55 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -127,7 +127,7 @@ Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis, } ArrayRef GetRankedTensorShape(Value value) { - return value.getType().cast().getShape(); + return mlir::cast(value.getType()).getShape(); } Value SliceRankedTensor(OpBuilder* builder, Value input, @@ -159,7 +159,7 @@ Value SliceRankedTensor(OpBuilder* builder, Value input, location, RankedTensorType::get( size_values, - input.getType().cast().getElementType()), + mlir::cast(input.getType()).getElementType()), input, slice_i2c_begin, slice_i2c_size); } @@ -170,7 +170,8 @@ Value CreateStridedSliceOp(mlir::Location loc, ArrayRef output_shape, int64_t ellipsis_mask, int64_t new_axis_mask, int64_t shrink_axis_mask, OpBuilder* builder) { auto output_type = RankedTensorType::get( - output_shape, input.getType().cast().getElementType()); + output_shape, + mlir::cast(input.getType()).getElementType()); auto begin_tensor = CreateI32DenseConst(builder, begin, loc); auto end_tensor = CreateI32DenseConst(builder, end, loc); auto strides_tensor = CreateI32DenseConst(builder, strides, loc); @@ -387,7 +388,8 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() { SmallVector output_shape{1, tensorflow::kTFDynamicSize}; auto input_types = fused_func_op_.getFunctionType().getInputs(); auto output_type = tensorflow::GetTypeFromTFTensorShape( - output_shape, input_.getType().cast().getElementType()); + output_shape, + mlir::cast(input_.getType()).getElementType()); fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(), input_types, output_type)); } @@ -410,7 +412,8 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() { // Create the fused LSTM op. 
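// Aside: a sketch of the dynamically-shaped result types built in the
// lstm_utils.cc hunks above, where kDynamic (kTFDynamicSize on the TF side)
// marks a dimension unknown until runtime. Assuming an MLIR environment:
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

static mlir::RankedTensorType BatchOfOneDynamicWidth(mlir::MLIRContext* ctx) {
  // tensor<1x?xf32>: leading dim fixed at 1, trailing dim dynamic.
  return mlir::RankedTensorType::get({1, mlir::ShapedType::kDynamic},
                                     mlir::Float32Type::get(ctx));
}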
SmallVector output_shape = {1, n_output_}; auto result_type = mlir::RankedTensorType::get( - output_shape, input_.getType().cast().getElementType()); + output_shape, + mlir::cast(input_.getType()).getElementType()); lstm_ = builder_.create( fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_, input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_, @@ -436,7 +439,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() { SmallVector func_output_shape = {1, tensorflow::kTFDynamicSize}; auto func_result_type = tensorflow::GetTypeFromTFTensorShape( func_output_shape, - input_.getType().cast().getElementType()); + mlir::cast(input_.getType()).getElementType()); auto tensor_cast = builder_.create( fused_func_op_.getLoc(), func_result_type, lstm_.getResult()); @@ -491,7 +494,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() { bias_ = fused_func_op_.getArgument(2); weight_ = fused_func_op_.getArgument(1); - weight_type_ = weight_.getType().cast(); + weight_type_ = mlir::cast(weight_.getType()); if (weight_type_.getRank() != 2) { return fused_func_op_.emitError() << "The weight tensor was not of rank 2"; @@ -505,7 +508,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() { n_cell_ = weight_type_.getDimSize(1) / num_gates_; projection_ = fused_func_op_.getArgument(3); - projection_type_ = projection_.getType().cast(); + projection_type_ = mlir::cast(projection_.getType()); if (projection_type_.getRank() != 2) { n_output_ = n_cell_; } else { @@ -532,7 +535,8 @@ LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() { } layer_norm_scale_ = fused_func_op_.getArgument(4); - layer_norm_scale_type_ = layer_norm_scale_.getType().cast(); + layer_norm_scale_type_ = + mlir::cast(layer_norm_scale_.getType()); if (layer_norm_scale_type_.getRank() != 1) { return fused_func_op_.emitError() << "The layer_norm_scale tensor was not of rank 1"; @@ -607,7 +611,7 @@ TF::ReshapeOp CreateFlattenOP(const Value& input, Location loc, LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits, Location loc, OpBuilder* builder, Operation** result) { - auto input_type = input.getType().cast(); + auto input_type = mlir::cast(input.getType()); SmallVector output_shape; int size_of_splits; if (input_type.getRank() < axis || axis < 0) return failure(); @@ -666,7 +670,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, if (time_major_attr == nullptr) return failure(); bool time_majored = time_major_attr.getValue(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) { func_op.emitError() << "Input type is not a ranked tensor type"; return failure(); @@ -692,7 +696,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, // Setup correct weights. 
RankedTensorType weight_type = - weight_kernel.getType().cast(); + mlir::cast(weight_kernel.getType()); if (weight_type.getRank() != 2) return func_op.emitError() << "The weight should be rank of 2"; @@ -700,7 +704,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, Transpose2D(builder, weight_kernel, weight_type, func_op.getLoc()); RankedTensorType recurrent_kernel_type = - recurrent_kernel.getType().cast(); + mlir::cast(recurrent_kernel.getType()); const int64_t n_output = recurrent_kernel_type.getDimSize(0); Value transpose_recurrent_kernel = Transpose2D( @@ -726,28 +730,28 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, // IndyLSTMs are a LSTM variant with diagonal recurrent weight // matrices. For optimization purposes these are provided as vectors. Value recurrent_to_input_weights = - indy ? CreateFlattenOP(recurrent_weights_array->getResult(0), - func_op.getLoc(), builder) - .getResult() - .cast() + indy ? mlir::cast( + CreateFlattenOP(recurrent_weights_array->getResult(0), + func_op.getLoc(), builder) + .getResult()) : recurrent_weights_array->getResult(0); Value recurrent_to_forget_weights = - indy ? CreateFlattenOP(recurrent_weights_array->getResult(1), - func_op.getLoc(), builder) - .getResult() - .cast() + indy ? mlir::cast( + CreateFlattenOP(recurrent_weights_array->getResult(1), + func_op.getLoc(), builder) + .getResult()) : recurrent_weights_array->getResult(1); Value recurrent_to_cell_weights = - indy ? CreateFlattenOP(recurrent_weights_array->getResult(2), - func_op.getLoc(), builder) - .getResult() - .cast() + indy ? mlir::cast( + CreateFlattenOP(recurrent_weights_array->getResult(2), + func_op.getLoc(), builder) + .getResult()) : recurrent_weights_array->getResult(2); Value recurrent_to_output_weights = - indy ? CreateFlattenOP(recurrent_weights_array->getResult(3), - func_op.getLoc(), builder) - .getResult() - .cast() + indy ? mlir::cast( + CreateFlattenOP(recurrent_weights_array->getResult(3), + func_op.getLoc(), builder) + .getResult()) : recurrent_weights_array->getResult(3); // Splits the bias into 4: @@ -765,7 +769,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, } auto result_type = mlir::RankedTensorType::get( output_shape, - final_inputs.getType().cast().getElementType()); + mlir::cast(final_inputs.getType()).getElementType()); Value none = CreateNoneValue(builder, func_op.getLoc()); auto lstm = builder->create( @@ -866,7 +870,8 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, // All the rest: states, device. 
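// Aside: a sketch of why the IndyLSTM recurrent weights above can be carried
// as flattened vectors: with a diagonal recurrent matrix, the recurrent
// matmul collapses to an elementwise product (editor's illustration):
#include <cstddef>

static void DiagonalRecurrentStep(const float* w_diag, const float* h_prev,
                                  float* out, size_t n) {
  // matmul(diag(w), h) == w * h elementwise.
  for (size_t i = 0; i < n; ++i) out[i] = w_diag[i] * h_prev[i];
}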
for (int i = 2; i < 5; ++i) { - auto result_type = func_op.getResultTypes()[i].dyn_cast(); + auto result_type = + mlir::dyn_cast(func_op.getResultTypes()[i]); outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f, func_op.getLoc())); output_types.push_back(result_type); diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 342bbb5c7fe382..7fe7ae8404137c 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -134,22 +134,18 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) { auto transpose_op = fused_lstm_func_.getBody().front().begin(); transpose_op++; - EXPECT_EQ( - transpose_op->getOperand(0).getType().cast().getDimSize( - 0), - 3); - EXPECT_EQ( - transpose_op->getOperand(0).getType().cast().getDimSize( - 1), - 12); - EXPECT_EQ( - transpose_op->getResult(0).getType().cast().getDimSize( - 0), - 12); - EXPECT_EQ( - transpose_op->getResult(0).getType().cast().getDimSize( - 1), - 3); + EXPECT_EQ(mlir::cast(transpose_op->getOperand(0).getType()) + .getDimSize(0), + 3); + EXPECT_EQ(mlir::cast(transpose_op->getOperand(0).getType()) + .getDimSize(1), + 12); + EXPECT_EQ(mlir::cast(transpose_op->getResult(0).getType()) + .getDimSize(0), + 12); + EXPECT_EQ(mlir::cast(transpose_op->getResult(0).getType()) + .getDimSize(1), + 3); auto it = fused_lstm_func_.getBody().back().rbegin(); EXPECT_EQ(it->getName().getStringRef(), @@ -161,33 +157,31 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) { EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumResults(), 1); // cifg = false, so input2input is not None. - EXPECT_FALSE(it->getOperand(1).getType().isa()); + EXPECT_FALSE(mlir::isa(it->getOperand(1).getType())); // input layer norm is None - EXPECT_TRUE(it->getOperand(20).getType().isa()); + EXPECT_TRUE(mlir::isa(it->getOperand(20).getType())); // proj_bias is F32 - EXPECT_TRUE(it->getOperand(17) - .getType() - .cast() + EXPECT_TRUE(mlir::cast(it->getOperand(17).getType()) .getElementType() .isF32()); // output gate bias is 0 since it is out of bounds of the bias tensor, so // we set its value as a const tensor of specified size and value 0. - EXPECT_TRUE(mlir::cast( - it->getOpOperand(15).get().getDefiningOp()) - .getValue() - .cast() - .getValues()[0] - .getValue() - .isExactlyValue(0.0f)); + EXPECT_TRUE( + mlir::cast(mlir::cast( + it->getOpOperand(15).get().getDefiningOp()) + .getValue()) + .getValues()[0] + .getValue() + .isExactlyValue(0.0f)); EXPECT_EQ(fused_lstm_func_.getFunctionType().getNumResults(), 1); auto output_types = fused_lstm_func_.getFunctionType().getResults(); SmallVector output_shape{1, mlir::ShapedType::kDynamic}; - EXPECT_EQ(output_types[0].cast().getShape().size(), + EXPECT_EQ(mlir::cast(output_types[0]).getShape().size(), output_shape.size()); for (int i = 0; i < output_shape.size(); i++) { - EXPECT_EQ(output_types[0].cast().getDimSize(i), + EXPECT_EQ(mlir::cast(output_types[0]).getDimSize(i), output_shape[i]); } } @@ -215,7 +209,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) { EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumResults(), 1); // cifg = true, so input2input is None. 
- EXPECT_TRUE(it->getOperand(1).getType().isa()); + EXPECT_TRUE(mlir::isa(it->getOperand(1).getType())); } TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) { @@ -242,23 +236,25 @@ TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) { EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumResults(), 1); // cifg = false, so input2input is not None. - EXPECT_FALSE(it->getOperand(1).getType().isa()); + EXPECT_FALSE(mlir::isa(it->getOperand(1).getType())); // input layer norm - EXPECT_FALSE(it->getOperand(20).getType().isa()); + EXPECT_FALSE(mlir::isa(it->getOperand(20).getType())); + EXPECT_EQ(mlir::cast(it->getOperand(20).getType()) + .getShape() + .size(), + 1); EXPECT_EQ( - it->getOperand(20).getType().cast().getShape().size(), - 1); - EXPECT_EQ(it->getOperand(20).getType().cast().getDimSize(0), - 3); + mlir::cast(it->getOperand(20).getType()).getDimSize(0), + 3); EXPECT_EQ(fused_ln_lstm_func_.getFunctionType().getNumResults(), 1); auto output_types = fused_ln_lstm_func_.getFunctionType().getResults(); SmallVector output_shape{1, mlir::ShapedType::kDynamic}; - EXPECT_EQ(output_types[0].cast().getShape().size(), + EXPECT_EQ(mlir::cast(output_types[0]).getShape().size(), output_shape.size()); for (int i = 0; i < output_shape.size(); i++) { - EXPECT_EQ(output_types[0].cast().getDimSize(i), + EXPECT_EQ(mlir::cast(output_types[0]).getDimSize(i), output_shape[i]); } } diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc index 5633068509faf4..cab3df456c0e00 100644 --- a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { @@ -74,7 +75,7 @@ LogicalResult ConvertNMSPaddedFunc::VerifySignature() { // The TFLite fused op does not support batching yet. // TODO(b/158709815): Add support for batches with padded NMS. 
auto boxes_type = - func_.getFunctionType().getInput(0).dyn_cast(); + mlir::dyn_cast(func_.getFunctionType().getInput(0)); if (boxes_type == nullptr || !boxes_type.hasRank() || boxes_type.getRank() != 2) { return func_.emitWarning() << "TFLite does not support batched input for " @@ -121,7 +122,7 @@ LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions( failed(AddFloatAttr(func, attrs, "w_scale", &fbb))) return failure(); auto use_regular_nms = - attrs.get("use_regular_nms").dyn_cast_or_null(); + mlir::dyn_cast_or_null(attrs.get("use_regular_nms")); if (!use_regular_nms) { return func.emitError() << "use_regular_nms attribute is not set or not a bool"; @@ -137,7 +138,7 @@ LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions( LogicalResult ConvertSSDPostProcessFunc::AddIntAttr( func::FuncOp func, DictionaryAttr attrs, const std::string& attribute, flexbuffers::Builder* builder) { - auto int_attr = attrs.get(attribute).dyn_cast_or_null(); + auto int_attr = mlir::dyn_cast_or_null(attrs.get(attribute)); if (!int_attr) { return func.emitError() << attribute.c_str() << " attribute is not set or not an integer"; @@ -149,7 +150,7 @@ LogicalResult ConvertSSDPostProcessFunc::AddIntAttr( LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr( func::FuncOp func, DictionaryAttr attrs, const std::string& attribute, flexbuffers::Builder* builder) { - auto float_attr = attrs.get(attribute).dyn_cast_or_null(); + auto float_attr = mlir::dyn_cast_or_null(attrs.get(attribute)); if (!float_attr) { return func.emitError() << attribute.c_str() << " attribute is not set or not a float"; @@ -160,7 +161,7 @@ LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr( LogicalResult ConvertSSDPostProcessFunc::HasIntAttr( func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) { - auto int_attr = attrs.get(attribute).dyn_cast_or_null(); + auto int_attr = mlir::dyn_cast_or_null(attrs.get(attribute)); if (!int_attr) { return func.emitWarning() << attribute.c_str() << " attribute is not set or not an integer"; @@ -170,7 +171,7 @@ LogicalResult ConvertSSDPostProcessFunc::HasIntAttr( LogicalResult ConvertSSDPostProcessFunc::HasFloatAttr( func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) { - auto float_attr = attrs.get(attribute).dyn_cast_or_null(); + auto float_attr = mlir::dyn_cast_or_null(attrs.get(attribute)); if (!float_attr) { return func.emitWarning() << attribute.c_str() << " attribute is not set or not a float"; diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc index c7944b67406907..f6595331c02415 100644 --- a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc @@ -21,6 +21,7 @@ limitations under the License. 
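// Aside: a sketch of how the Add*Attr helpers above serialize attributes:
// TFLite custom ops carry their options as a FlexBuffer map. Keys and values
// here are illustrative, assuming the flexbuffers API from flatbuffers:
#include <cstdint>
#include <vector>
#include "flatbuffers/flexbuffers.h"

static std::vector<uint8_t> BuildCustomOptions() {
  flexbuffers::Builder fbb;
  const size_t map_start = fbb.StartMap();
  fbb.Int("max_detections", 10);
  fbb.Float("nms_iou_threshold", 0.5f);
  fbb.Bool("use_regular_nms", false);
  fbb.EndMap(map_start);
  fbb.Finish();  // required before reading the buffer back
  return fbb.GetBuffer();
}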
#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -45,14 +46,15 @@ inline LogicalResult HasIntegerArrayWithSize(func::FuncOp* func, const DictionaryAttr& attrs, const std::string& attr_name, int N) { - ArrayAttr array_attr = attrs.get(attr_name).dyn_cast_or_null(); + ArrayAttr array_attr = + mlir::dyn_cast_or_null(attrs.get(attr_name)); if (array_attr == nullptr || array_attr.size() != N) { return func->emitWarning() << "'" << attr_name << "' attribute for " << kMaxUnpooling << " must be set and has size of " << N; } for (Attribute integer_attr : array_attr.getValue()) { - IntegerAttr value = integer_attr.dyn_cast(); + IntegerAttr value = mlir::dyn_cast(integer_attr); if (!value) { return func->emitWarning() << "'" << attr_name << "' attribute for " << kMaxUnpooling @@ -66,7 +68,8 @@ inline LogicalResult GetIntegerArraySafe( func::FuncOp* func, const DictionaryAttr& attrs, const std::string& attr_name, llvm::SmallVectorImpl* results, int N) { - ArrayAttr array_attr = attrs.get(attr_name).dyn_cast_or_null(); + ArrayAttr array_attr = + mlir::dyn_cast_or_null(attrs.get(attr_name)); if (array_attr == nullptr || array_attr.size() != N) { return func->emitError() << "'" << attr_name << "' attribute for " << kMaxUnpooling @@ -75,7 +78,7 @@ inline LogicalResult GetIntegerArraySafe( results->reserve(N); for (Attribute integer_attr : array_attr.getValue()) { - IntegerAttr value = integer_attr.dyn_cast(); + IntegerAttr value = mlir::dyn_cast(integer_attr); if (!value) { return func->emitError() << "'" << attr_name << "' attribute for " << kMaxUnpooling @@ -132,13 +135,12 @@ LogicalResult ConvertMaxUnpoolingFunc::VerifySignature() { } // Retrieves padding. - auto padding = attrs.get("padding").dyn_cast_or_null(); + auto padding = mlir::dyn_cast_or_null(attrs.get("padding")); if (!padding) { return func_.emitWarning() << "'padding' attribute for " << kMaxUnpooling << " is not set or not a string"; } - if (!padding.getValue().equals("VALID") && - !padding.getValue().equals("SAME")) { + if (padding.getValue() != "VALID" && padding.getValue() != "SAME") { return func_.emitWarning() << "Padding for " << kMaxUnpooling << " must be 'SAME' or 'VALID'"; } @@ -166,14 +168,14 @@ LogicalResult ConvertMaxUnpoolingFunc::CreateCustomOptions( pool_params.stride_width = strides[1]; // Retrieves padding. - auto padding = attrs.get("padding").dyn_cast_or_null(); + auto padding = mlir::dyn_cast_or_null(attrs.get("padding")); if (!padding) { return func_.emitError() << "'padding' attribute for " << kMaxUnpooling << " is not set or not a string"; } - if (padding.getValue().equals("VALID")) { + if (padding.getValue() == "VALID") { pool_params.padding = kTfLitePaddingValid; - } else if (padding.getValue().equals("SAME")) { + } else if (padding.getValue() == "SAME") { pool_params.padding = kTfLitePaddingSame; } else { return func_.emitError() @@ -224,22 +226,22 @@ LogicalResult ConvertDenseImageWarpFunc::VerifySignature() { } // Check types and shapes. 
- auto image_type = - func_.getFunctionType().getInput(0).dyn_cast_or_null(); + auto image_type = mlir::dyn_cast_or_null( + func_.getFunctionType().getInput(0)); if (!image_type || !image_type.getElementType().isF32() || image_type.getRank() != 4) { return func_.emitWarning() << "Image should be a 4D float tensor"; } - auto flow_type = - func_.getFunctionType().getInput(1).dyn_cast_or_null(); + auto flow_type = mlir::dyn_cast_or_null( + func_.getFunctionType().getInput(1)); if (!flow_type || !flow_type.getElementType().isF32() || flow_type.getRank() != 4) { return func_.emitWarning() << "Flow should be a 4D float tensor"; } - auto output_type = - func_.getFunctionType().getResult(0).dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null( + func_.getFunctionType().getResult(0)); if (!output_type || !output_type.getElementType().isF32() || output_type.getRank() != 4) { return func_.emitWarning() << "Output should be a 4D float tensor"; diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc index 7ce9c56086e691..5e9bcc16d27537 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -62,11 +62,13 @@ inline ConstBytesAttr CustomOption(OpBuilder* builder, } inline TensorType GetInputType(func::FuncOp func, int idx) { - return func.getFunctionType().getInput(idx).dyn_cast_or_null(); + return mlir::dyn_cast_or_null( + func.getFunctionType().getInput(idx)); } inline TensorType GetResultType(func::FuncOp func, int idx) { - return func.getFunctionType().getResult(idx).dyn_cast_or_null(); + return mlir::dyn_cast_or_null( + func.getFunctionType().getResult(idx)); } inline bool RankEquals(const TensorType& type, int rank) { @@ -89,7 +91,7 @@ LogicalResult VerifyWhitespaceTokenizer(func::FuncOp func) { // * 2nd output is the inner offset; // * 3rd output is the outer offset. 
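// Aside: a sketch of the ragged-tensor convention behind the tokenizer
// signature above: a flat values tensor plus row_splits, where row i spans
// values[row_splits[i] .. row_splits[i+1]) (editor's illustration):
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

static std::vector<std::vector<std::string>> DecodeRagged(
    const std::vector<std::string>& values,
    const std::vector<int64_t>& row_splits) {
  std::vector<std::vector<std::string>> rows;
  for (size_t i = 0; i + 1 < row_splits.size(); ++i)
    rows.emplace_back(values.begin() + row_splits[i],
                      values.begin() + row_splits[i + 1]);
  return rows;
}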
auto input_type = GetInputType(func, 0); - if (!input_type || !input_type.getElementType().isa() || + if (!input_type || !mlir::isa(input_type.getElementType()) || !input_type.hasRank()) { return func.emitError() << "Input should be a string tensor"; } @@ -107,7 +109,7 @@ LogicalResult VerifyWhitespaceTokenizer(func::FuncOp func) { auto value_type = GetResultType(func, 0); if (!RankEquals(value_type, 1) || - !value_type.getElementType().isa()) { + !mlir::isa(value_type.getElementType())) { return func.emitError() << "1st output should be string tensor"; } if (func.getNumResults() > 1) { @@ -157,12 +159,14 @@ LogicalResult VerifyNgrams(func::FuncOp func) { int row_splits = func.getFunctionType().getInputs().size() - kRowSplits; if (row_splits == 0) { auto input_values = GetInputType(func, kValues); - if (!input_values || !input_values.getElementType().isa()) { + if (!input_values || + !mlir::isa(input_values.getElementType())) { return func.emitError() << "Input " << kValues << " should be a string tensor"; } auto output_values = GetResultType(func, kValues); - if (!output_values || !output_values.getElementType().isa()) { + if (!output_values || + !mlir::isa(output_values.getElementType())) { return func.emitError() << "Output " << kValues << " should be a string tensor"; } @@ -175,13 +179,13 @@ LogicalResult VerifyNgrams(func::FuncOp func) { } else { auto input_values = GetInputType(func, kValues); if (!RankEquals(input_values, 1) || - !input_values.getElementType().isa()) { + !mlir::isa(input_values.getElementType())) { return func.emitError() << "Input " << kValues << " should be a 1D string tensor"; } auto output_values = GetResultType(func, kValues); if (!RankEquals(output_values, 1) || - !output_values.getElementType().isa()) { + !mlir::isa(output_values.getElementType())) { return func.emitError() << "Output " << kValues << " should be a 1D string tensor"; } @@ -211,14 +215,14 @@ LogicalResult CreateNgramsCustomOption(func::FuncOp func, DictionaryAttr attrs, flexbuffers::Builder fbb; size_t start_map = fbb.StartMap(); - auto width = attrs.get("width").dyn_cast_or_null(); + auto width = mlir::dyn_cast_or_null(attrs.get("width")); if (!width) { return func.emitError() << "'width' attribute is not set or not an integer"; } fbb.Int("width", width.getInt()); auto string_separator = - attrs.get("string_separator").dyn_cast_or_null(); + mlir::dyn_cast_or_null(attrs.get("string_separator")); if (!string_separator) { return func.emitError() << "'string_separator' attribute is not set or not a string"; @@ -229,14 +233,14 @@ LogicalResult CreateNgramsCustomOption(func::FuncOp func, DictionaryAttr attrs, string_separator.getValue().size()); fbb.String("string_separator", string_separator_str); - auto axis = attrs.get("axis").dyn_cast_or_null(); + auto axis = mlir::dyn_cast_or_null(attrs.get("axis")); if (!axis) { return func.emitError() << "'axis' attribute is not set or not an integer"; } fbb.Int("axis", axis.getInt()); auto reduction_type = - attrs.get("reduction_type").dyn_cast_or_null(); + mlir::dyn_cast_or_null(attrs.get("reduction_type")); if (!reduction_type) { return func.emitError() << "'reduction_type' attribute is not set or not a string"; @@ -277,23 +281,23 @@ LogicalResult VerifySgnnProjection(func::FuncOp func, FuncAttr attr) { return func.emitError() << "Mismatched number of inputs and outputs."; } auto values_type = GetInputType(func, 0); - if (!values_type || !values_type.getElementType().isa()) { + if (!values_type || !mlir::isa(values_type.getElementType())) { return 
func.emitError() << "First input should be a string tensor"; } auto row_splits_type = GetInputType(func, 1); if (!row_splits_type || - !row_splits_type.getElementType().isa()) { + !mlir::isa(row_splits_type.getElementType())) { return func.emitError() << "Second input should be an integer tensor"; } auto hash_seed = - attr.getAttrs().get("hash_seed").dyn_cast_or_null(); + mlir::dyn_cast_or_null(attr.getAttrs().get("hash_seed")); if (!hash_seed) { return func.emitError() << "'hash_seed' attribute is not set or not an array"; } auto output_type = GetResultType(func, 0); - if (!output_type || !output_type.getElementType().isa() || + if (!output_type || !mlir::isa(output_type.getElementType()) || !RankEquals(output_type, 2)) { return func.emitError() << "Output should be a 2D float tensor."; } @@ -302,7 +306,8 @@ LogicalResult VerifySgnnProjection(func::FuncOp func, FuncAttr attr) { << "Output 2nd dimension should be the num of hash seeds."; } - auto buckets = attr.getAttrs().get("buckets").dyn_cast_or_null(); + auto buckets = + mlir::dyn_cast_or_null(attr.getAttrs().get("buckets")); if (!buckets) { return func.emitError() << "'buckets' attribute is not set or not int"; } @@ -316,15 +321,16 @@ LogicalResult CreateSgnnProjectionCustomOption( flexbuffers::Builder fbb; size_t start_map = fbb.StartMap(); - auto hash_seed = attrs.get("hash_seed").dyn_cast_or_null(); + auto hash_seed = mlir::dyn_cast_or_null(attrs.get("hash_seed")); auto vector_start = fbb.StartVector("hash_seed"); for (int i = 0; i < hash_seed.size(); i++) { fbb.Add(static_cast( - (hash_seed.getValue().data() + i)->dyn_cast().getInt())); + mlir::dyn_cast(*(hash_seed.getValue().data() + i)) + .getInt())); } fbb.EndVector(vector_start, /*typed=*/true, /*fixed=*/false); - auto buckets = attrs.get("buckets").dyn_cast_or_null(); + auto buckets = mlir::dyn_cast_or_null(attrs.get("buckets")); fbb.Int("buckets", buckets.getInt()); fbb.EndMap(start_map); diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc index 4ddb6b1c4411be..5138d7475452cd 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc @@ -33,7 +33,7 @@ namespace { void Register(const std::string& op_name, OpRegistry* registry) { registry->Register([op_name](OpRegistrationData* op_reg_data) -> Status { op_reg_data->op_def.set_name(op_name); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); }); } diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h index 9fce1bc44387c3..d73bf37ebd748a 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include #include +#include #include "llvm/ADT/ArrayRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project @@ -62,6 +63,139 @@ inline bool OpHasSameStaticShapes(Operation* op) { return true; } +// Utility function to map final permutation to initial permutation +// initial -> permutation1 -> permutation2 -> final +inline DenseElementsAttr RemapPermutation(Value permutation1, + DenseElementsAttr perm2_const) { + SmallVector initial_permutation; + DenseElementsAttr perm1_const; + + SmallVector new_permutation; + if (matchPattern(permutation1, m_Constant(&perm1_const))) { + for (int32_t idx = 0; idx < perm1_const.getNumElements(); ++idx) { + initial_permutation.push_back(idx); + } + for (auto perm : perm2_const.getValues()) { + new_permutation.push_back( + initial_permutation[perm1_const + .getValues()[perm.getSExtValue()] + .getSExtValue()]); + } + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(new_permutation.size())}, + mlir::IntegerType::get(permutation1.getContext(), 32)), + llvm::ArrayRef(new_permutation)); +} + +// Utility function to map final permutation to initial permutation +// initial -> permutation1 -> permutation2 -> final +inline DenseElementsAttr RemapPermutation(Value permutation1, + Value permutation2) { + DenseElementsAttr perm2_const; + (void)matchPattern(permutation2, m_Constant(&perm2_const)); + + return RemapPermutation(permutation1, perm2_const); +} + +// Returns true if the transpose op is trivial. Trivial means that +// the permutation is a cyclic permutation of the original shape with only the +// identity dimensions permuted. +inline bool IsTransposeTrivial(llvm::ArrayRef input_shape, + Value perm) { + DenseElementsAttr perm_values_attr; + if (!matchPattern(perm, m_Constant(&perm_values_attr))) return false; + + SmallVector perm_values; + for (const auto& dim : perm_values_attr.getValues()) + perm_values.push_back(dim.getSExtValue()); + + // This should never happen unless the input graph is malformed. + if (input_shape.size() != perm_values.size()) { + return false; + } + + SmallVector old_major_index_ordering; + SmallVector new_major_index_ordering; + for (int i = 0, end = input_shape.size(); i < end; i++) { + if (input_shape[i] != 1) { + old_major_index_ordering.push_back(i); + } + + if (input_shape[perm_values[i]] != 1) { + new_major_index_ordering.push_back(perm_values[i]); + } + } + return (old_major_index_ordering == new_major_index_ordering); +} + +// Returns the permutation that maps the input shape to the output shape. +// This is only valid for trivial reshape ops. +inline DenseElementsAttr GetPermutationFromTrivialReshape( + ShapedType input_type, ShapedType output_type) { + ArrayRef in_shape = input_type.getShape(); + ArrayRef out_shape = output_type.getShape(); + + // Get the indexes of the non-identity dimensions and the identity dimensions + // in the input shape. + SmallVector input_nonidentity_dims_index_array; + SmallVector input_identity_dims_index_array; + + // Since the reshape is trivial, the input and output shapes should have the + // same number of dimensions. And the non-identity dimensions must be in the + // same cyclic order. + for (size_t idx = 0; idx < in_shape.size(); ++idx) { + if (in_shape[idx] != 1) { + input_nonidentity_dims_index_array.push_back(idx); + } else { + input_identity_dims_index_array.push_back(idx); + } + } + + // Get the permutation that maps the input shape to the output shape. 
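// Aside: a sketch of what the RemapPermutation helper above computes:
// applying perm1 and then perm2 to a tensor equals applying the single
// composed permutation composed[i] = perm1[perm2[i]] (editor's illustration):
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int32_t> ComposePermutations(
    const std::vector<int32_t>& perm1, const std::vector<int32_t>& perm2) {
  std::vector<int32_t> composed(perm2.size());
  for (size_t i = 0; i < perm2.size(); ++i) composed[i] = perm1[perm2[i]];
  return composed;
}
// e.g. perm1 = {1, 0, 2} then perm2 = {2, 1, 0} composes to {2, 0, 1}.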
+ SmallVector permutation; + size_t nonidentity_dims_index_pointer = 0; + size_t identity_dims_index_pointer = 0; + for (auto out_dim : out_shape) { + if (out_dim != 1) { + permutation.push_back( + input_nonidentity_dims_index_array[nonidentity_dims_index_pointer++]); + } else { + permutation.push_back( + input_identity_dims_index_array[identity_dims_index_pointer++]); + } + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(permutation.size())}, + mlir::IntegerType::get(input_type.getContext(), 32)), + llvm::ArrayRef(permutation)); +} + +// Returns true if the reshape op is equivalent to a transpose op. +// This is true if the reshape op is a trivial reshape op, meaning no change in +// the order of non-identity dimensions. +inline bool IsReshapeEquivalentToTranspose(ShapedType input_type, + ShapedType output_type) { + std::vector in_shape{input_type.getShape().vec()}; + std::vector out_shape{output_type.getShape().vec()}; + + // If the reshape changes the number of dimensions, it cannot be interpreted + // as a transpose. + if (in_shape.size() != out_shape.size()) { + return false; + } + + in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), + in_shape.end()); + out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), + out_shape.end()); + return in_shape == out_shape; +} + // Checks if all elements in the constant attribute value are 1. inline bool IsAllOnesConstant(Attribute value) { auto values = value.cast().getValues(); diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index 42af8c67b2a7ce..067c95f1ce4c15 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -19,6 +19,9 @@ include "mlir/IR/OpBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/IR/PatternBase.td" +def CreateNoneValue : NativeCodeCall< "$_builder.create($0.getLoc(), $_builder.getUnitAttr())">; + // Returns shape of a ranked tensor. // if called without a ranked tensor it will fail. def GetShape: NativeCodeCall<"GetShape($0)">; @@ -26,6 +29,27 @@ def GetShape: NativeCodeCall<"GetShape($0)">; // Constraint that values in list attribute are all ones. def IsAllOnesConstant : Constraint>; +// Constraint that checks if the transpose op is trivial. Trivial means that +// the permutation is a cyclic permutation of the original shape with only the +// identity dimensions permuted. +def IsTransposeTrivial : Constraint().getShape(), $1)">>; + +// Constraint that checks if the reshape op is equivalent to a transpose op. +// This is true if the reshape op is a trivial reshape op, meaning no change in +// the order of non-identity dimensions. +def IsReshapeEquivalentToTranspose : Constraint()," + "$1.getType().cast())">>; + +// Returns the permutation of the trivial reshape op, this will be used to +// construct the transpose op. +def GetPermutationFromTrivialReshape : NativeCodeCall< + "TFL::GetPermutationFromTrivialReshape(" + "$0.getType().cast()," + "$1.getType().cast())">; + // Constraint that checks if all values in offset between two // attributes are non-negative.
def HasNonNegativeOffset : Constraint>; @@ -59,6 +83,10 @@ def SameElementType : Constraint< class GetTransposedType : NativeCodeCall< "GetTransposedType($0, " # perm # ")">; +// Function to map final permutation to initial permutation +// initial -> permutation1 -> permutation2 -> final +def RemapPermutation: NativeCodeCall<"RemapPermutation($0, $1)">; + // Checks if all of an ops inputs are the same static shape. // BUILD NOTE: "OpHasSameStaticShapes" here refers to the C++ function defined // in `utils/utils.h`. The `utils.h` header is included in `tfl_ops.h` so all diff --git a/tensorflow/compiler/mlir/lite/utils/validators.cc b/tensorflow/compiler/mlir/lite/utils/validators.cc index f4714e00e5f2a4..902d7b144ba69d 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.cc +++ b/tensorflow/compiler/mlir/lite/utils/validators.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TFL { @@ -36,45 +37,45 @@ bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x, auto elements = attr.getValue(); if (elements.size() != 4 || std::any_of(elements.begin(), elements.end(), - [](Attribute e) { return !e.isa(); })) + [](Attribute e) { return !mlir::isa(e); })) return false; - if (elements.front().cast().getInt() != 1 || - elements.back().cast().getInt() != 1) + if (mlir::cast(elements.front()).getInt() != 1 || + mlir::cast(elements.back()).getInt() != 1) return false; Builder b(op->getContext()); - *x = b.getI32IntegerAttr(elements[1].cast().getInt()); - *y = b.getI32IntegerAttr(elements[2].cast().getInt()); + *x = b.getI32IntegerAttr(mlir::cast(elements[1]).getInt()); + *y = b.getI32IntegerAttr(mlir::cast(elements[2]).getInt()); return true; } // Returns true if the attribute is an integer list of the form [1, X, Y, 1]. bool TFIntListIs1XY1(const Attribute attr) { - const auto &elements = attr.cast().getValue(); + const auto &elements = mlir::cast(attr).getValue(); if (elements.size() != 4 || std::any_of(elements.begin(), elements.end(), - [](Attribute e) { return !e.isa(); })) + [](Attribute e) { return !mlir::isa(e); })) return false; - if (elements.front().cast().getValue() != 1 || - elements.back().cast().getValue() != 1) + if (mlir::cast(elements.front()).getValue() != 1 || + mlir::cast(elements.back()).getValue() != 1) return false; return true; } // Returns true if the attribute is an integer list of the form [1, 1, X, Y]. 
bool TFIntListIs11XY(const Attribute attr) { - const auto &elements = attr.cast().getValue(); + const auto &elements = mlir::cast(attr).getValue(); if (elements.size() != 4 || std::any_of(elements.begin(), elements.end(), - [](Attribute e) { return !e.isa(); })) + [](Attribute e) { return !mlir::isa(e); })) return false; const Attribute *data = elements.data(); - if (data[0].cast().getValue() != 1 || - data[1].cast().getValue() != 1) + if (mlir::cast(data[0]).getValue() != 1 || + mlir::cast(data[1]).getValue() != 1) return false; return true; } @@ -91,17 +92,17 @@ bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x, auto elements = attr.getValue(); if (elements.size() != 5 || std::any_of(elements.begin(), elements.end(), - [](Attribute e) { return !e.isa(); })) + [](Attribute e) { return !mlir::isa(e); })) return false; - if (elements.front().cast().getInt() != 1 || - elements.back().cast().getInt() != 1) + if (mlir::cast(elements.front()).getInt() != 1 || + mlir::cast(elements.back()).getInt() != 1) return false; Builder b(op->getContext()); - *x = b.getI32IntegerAttr(elements[1].cast().getInt()); - *y = b.getI32IntegerAttr(elements[2].cast().getInt()); - *z = b.getI32IntegerAttr(elements[3].cast().getInt()); + *x = b.getI32IntegerAttr(mlir::cast(elements[1]).getInt()); + *y = b.getI32IntegerAttr(mlir::cast(elements[2]).getInt()); + *z = b.getI32IntegerAttr(mlir::cast(elements[3]).getInt()); return true; } @@ -109,10 +110,10 @@ bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x, // Returns true if every element of the attribute is 1. All elements of `attr` // must be `IntegerAttr`. bool TFIntListIsAllOnes(const Attribute attr) { - const auto &elements = attr.cast().getValue(); + const auto &elements = mlir::cast(attr).getValue(); return !std::any_of(elements.begin(), elements.end(), [](Attribute e) { - return e.cast().getValue() != 1; + return mlir::cast(e).getValue() != 1; }); } @@ -133,7 +134,7 @@ bool IsDimensionsDegenerateExceptLastOne(ArrayRef elements_shape) { } bool IsDimensionsDegenerateExceptLastOne(TypedAttr val) { - if (auto ranked_type = val.getType().dyn_cast()) { + if (auto ranked_type = mlir::dyn_cast(val.getType())) { return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape()); } return false; diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h index 08d2e7b068b4be..0e7370c5fa499b 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.h +++ b/tensorflow/compiler/mlir/lite/utils/validators.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TFL { @@ -70,21 +71,21 @@ bool TFIntListIsAllOnes(Attribute attr); // Returns true iff the given value is a float32 tensor. // is "DT_FLOAT". inline bool TFTypeIsFloat32Tensor(Value value) { - auto tensorType = value.getType().dyn_cast(); + auto tensorType = mlir::dyn_cast(value.getType()); if (!tensorType) return false; return tensorType.getElementType().isF32(); } // Returns true iff the given value is a bf16 tensor. inline bool TFTypeIsBFloat16Tensor(Value value) { - auto tensorType = value.getType().dyn_cast(); + auto tensorType = mlir::dyn_cast(value.getType()); if (!tensorType) return false; return tensorType.getElementType().isBF16(); } // Returns true iff the given value is a f16 tensor. 
inline bool TFTypeIsHalfTensor(Value value) { - auto tensorType = value.getType().dyn_cast(); + auto tensorType = mlir::dyn_cast(value.getType()); if (!tensorType) return false; return tensorType.getElementType().isF16(); } diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index f5912553f10dbe..8f3261f6574ff7 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/utils/name_utils.h" static inline absl::string_view StringRefToView(llvm::StringRef ref) { @@ -123,7 +124,7 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) { // If the location is none of the expected types, then simply use name // generated using the op type. Follow TF convention and append the result // index unless 0. - if (auto result = val.dyn_cast()) { + if (auto result = mlir::dyn_cast(val)) { if (result.getResultNumber() > 0) return llvm::formatv("{0}:{1}", result.getOwner()->getName().getStringRef(), @@ -131,7 +132,7 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) { return std::string(result.getOwner()->getName().getStringRef()); } // Use the ASM syntax for BlockArgument - if (auto arg = val.dyn_cast()) { + if (auto arg = mlir::dyn_cast(val)) { return "arg" + std::to_string(arg.getArgNumber()); } return ""; diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 4841c0ad85714f..de1226c68c39d0 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -66,7 +66,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tosa/tfl_passes.h" #include "tensorflow/compiler/mlir/tosa/transforms/passes.h" #include "xla/mlir/framework/transforms/passes.h" -#include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/status_macros.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -92,7 +91,6 @@ static void RegisterPasses() { mlir::registerTensorFlowPasses(); mlir::TFDevice::registerTensorFlowDevicePasses(); mlir::mhlo::registerAllMhloPasses(); - mlir::lmhlo::registerAllLmhloPasses(); // These are in compiler/mlir/xla and not part of the above MHLO // passes. 
mlir::mhlo::registerTfXlaPasses(); diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD index da122b67993af7..448b34717282b7 100644 --- a/tensorflow/compiler/mlir/quantization/common/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -40,16 +40,18 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", "//tensorflow/core:framework_lite", "//tensorflow/core/ir/types:Dialect", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc index 540eff26685968..1367e7e5eaa175 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc @@ -42,14 +42,14 @@ namespace mlir::quant { using ::mlir::stablehlo::DotGeneralOp; bool HasStaticShape(Value value) { - auto shaped_type = value.getType().dyn_cast(); + auto shaped_type = mlir::dyn_cast(value.getType()); if (!shaped_type) return false; return shaped_type.hasStaticShape(); } bool HasStaticShapeAtDims(Value value, const ArrayRef dims) { - auto shaped_type = value.getType().dyn_cast(); + auto shaped_type = mlir::dyn_cast(value.getType()); if (!shaped_type || !shaped_type.hasRank()) return false; for (auto dim : dims) { @@ -59,9 +59,9 @@ bool HasStaticShapeAtDims(Value value, const ArrayRef dims) { } Type CloneTypeWithNewElementType(Type old_type, Type element_type) { - if (!old_type.isa()) return {}; + if (!mlir::isa(old_type)) return {}; - return old_type.cast().clone(element_type); + return mlir::cast(old_type).clone(element_type); } SmallVector CloneOpWithReplacedOperands( @@ -133,9 +133,11 @@ absl::StatusOr IsDotGeneralFullyConnected(DotGeneralOp dot_general_op) { const ArrayRef rhs_contracting_dims = dot_dimension_numbers.getRhsContractingDimensions(); const int64_t input_rank = - dot_general_op.getOperand(0).getType().dyn_cast().getRank(); + mlir::dyn_cast(dot_general_op.getOperand(0).getType()) + .getRank(); const int64_t filter_rank = - dot_general_op.getOperand(1).getType().dyn_cast().getRank(); + mlir::dyn_cast(dot_general_op.getOperand(1).getType()) + .getRank(); // The following conditions are such requirements: // - rank(lhs) is 1 or 2 // - rank(rhs) = 2 @@ -164,7 +166,8 @@ std::optional GetDotGeneralQuantizationDim( DotGeneralOp dot_general_op) { if (dot_general_op == nullptr) return std::nullopt; const int64_t filter_rank = - dot_general_op.getOperand(1).getType().dyn_cast().getRank(); + mlir::dyn_cast(dot_general_op.getOperand(1).getType()) + .getRank(); // To quantize rhs per-channel, we currently only consider the case where // `stablehlo.dot_general` is legalizable to `tfl.fully_connected`. 
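// Aside: a sketch of the fully-connected conditions listed above for
// stablehlo.dot_general, stated independently of MLIR (assumes no batch
// dimensions; names are illustrative, not the pass's own helper):
#include <cstdint>
#include <vector>

static bool IsFullyConnectedLike(int64_t lhs_rank, int64_t rhs_rank,
                                 const std::vector<int64_t>& lhs_contracting,
                                 const std::vector<int64_t>& rhs_contracting) {
  // rank(lhs) is 1 or 2, rank(rhs) == 2, and the contraction pairs lhs's
  // last dimension with rhs's dimension 0; rhs dim 1 is then the
  // per-channel (output feature) dimension.
  return (lhs_rank == 1 || lhs_rank == 2) && rhs_rank == 2 &&
         lhs_contracting == std::vector<int64_t>{lhs_rank - 1} &&
         rhs_contracting == std::vector<int64_t>{0};
}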
@@ -174,4 +177,8 @@ std::optional GetDotGeneralQuantizationDim( return filter_rank - 1; } +bool ContainsConvOrDot(StringRef str) { + return str.contains("_conv") || str.contains("_dot_general"); +} + } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index 490a77a3b73ffa..e94f9359d6fad2 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -42,6 +42,10 @@ namespace mlir::quant { constexpr char kAttrMapAttribute[] = "attr_map"; +// Name of the string attribute attached to `XlaCallModuleOp`, which is the +// textproto representation of `Method`. +inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; + // Permutation from the NHWC tensor format to NCHW. This is an inverse // permutation of `kNchwToNhwcPermutation`. inline constexpr std::array kNhwcToNchwPermutation = {0, 3, 1, 2}; @@ -65,7 +69,7 @@ bool HasStaticShapeAtDims(Value value, ArrayRef dims); // Whether `value` has known rank of `rank`. Returns false when it is not a // `ShapedType` or its rank is unknown. inline bool HasRankOf(Value value, const int64_t rank) { - auto shaped_type = value.getType().dyn_cast_or_null(); + auto shaped_type = mlir::dyn_cast_or_null(value.getType()); return shaped_type && shaped_type.hasRank() && shaped_type.getRank() == rank; } @@ -215,7 +219,7 @@ Operation* FindOperandOfType(Operation* op) { // Returns the function attribute for the given call op which is lifted for // quantization. inline FlatSymbolRefAttr GetFuncAttr(TF::PartitionedCallOp call_op) { - return call_op.getFAttr().template dyn_cast(); + return mlir::dyn_cast(call_op.getFAttr()); } inline FlatSymbolRefAttr GetFuncAttr(TF::XlaCallModuleOp call_op) { @@ -248,6 +252,9 @@ absl::StatusOr IsDotGeneralFullyConnected( std::optional GetDotGeneralQuantizationDim( ::mlir::stablehlo::DotGeneralOp dot_general_op); +// Checks if a `StringRef` contains 'conv' or 'dot_general'. 
+bool ContainsConvOrDot(StringRef str); + } // namespace mlir::quant #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc index ca0df77f81b51c..720616309afe38 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc @@ -98,6 +98,21 @@ constexpr absl::string_view kModuleXlaCallModule = R"mlir( } )mlir"; +constexpr absl::string_view kModuleDotWeightOnlyPtq = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "weight_only_ptq { }"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_dot_general_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x2xf32>) -> tensor + return %0 : tensor + } + } +)mlir"; + constexpr absl::string_view kModuleXlaCallModuleNoEntryNoQuantTrait = R"mlir( module { func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { @@ -128,8 +143,8 @@ constexpr absl::string_view kModulePartitionedCall = R"mlir( constexpr absl::string_view kModuleHybridQuantized = R"mlir( module { - func.func @main(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3x!quant.uniform> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32>) { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> + func.func @main(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3x!quant.uniform> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32>) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } } @@ -526,5 +541,31 @@ TEST_F(AttrsAndConstraintsTest, DotGeneralBatchMatmulReturnsNullQuantDim) { EXPECT_THAT(GetDotGeneralQuantizationDim(dot_general_op), Eq(std::nullopt)); } +TEST_F(AttrsAndConstraintsTest, ContainsConvOrDotTrue) { + OwningOpRef module_op = + ParseModuleOpString(kModuleDotWeightOnlyPtq); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + const StringRef function_name = GetEntryFunctionName(call_op); + EXPECT_TRUE(ContainsConvOrDot(function_name)); +} + +TEST_F(AttrsAndConstraintsTest, ContainsConvOrDotFalse) { + OwningOpRef module_op = + ParseModuleOpString(kModuleXlaCallModuleNoEntryNoQuantTrait); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + const StringRef function_name = 
GetEntryFunctionName(call_op); + EXPECT_FALSE(ContainsConvOrDot(function_name)); +} + } // namespace } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc b/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc index 7bd7424e4d1c6a..6ddebac1ff00f9 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/ir/QuantOpsDialect.cc.inc" namespace mlir::quant::ir { @@ -49,20 +50,20 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor) { /// The quantization specification should match the expressed type. static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) { - if (auto typeAttr = quantSpec.dyn_cast()) { + if (auto typeAttr = mlir::dyn_cast(quantSpec)) { Type spec = typeAttr.getValue(); - if (spec.isa()) return false; + if (mlir::isa(spec)) return false; // The spec should be either a quantized type which is compatible to the // expressed type, or a primitive type which is as same as the // (element type of) the expressed type. - if (auto quantizedType = spec.dyn_cast()) + if (auto quantizedType = mlir::dyn_cast(spec)) return quantizedType.isCompatibleExpressedType(expressed); - if (auto tensorType = expressed.dyn_cast()) + if (auto tensorType = mlir::dyn_cast(expressed)) return spec == tensorType.getElementType(); - if (auto vectorType = expressed.dyn_cast()) + if (auto vectorType = mlir::dyn_cast(expressed)) return spec == vectorType.getElementType(); } return false; @@ -97,13 +98,13 @@ LogicalResult QuantizeRegionOp::verify() { } LogicalResult StatisticsOp::verify() { - auto tensorArg = getArg().getType().dyn_cast(); + auto tensorArg = mlir::dyn_cast(getArg().getType()); if (!tensorArg) return emitOpError("arg needs to be tensor type."); // Verify layerStats attribute. 
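// The block below requires layerStats to be a rank-1, two-element
// floating-point tensor holding a [min, max] pair for the whole layer;
// axisStats, verified further down, is the per-axis analogue: a rank-2
// float tensor whose trailing dimension is 2 (one [min, max] pair per entry).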
{ auto layerStatsType = getLayerStats().getShapedType(); - if (!layerStatsType.getElementType().isa()) { + if (!mlir::isa(layerStatsType.getElementType())) { return emitOpError("layerStats must have a floating point element type"); } if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { @@ -120,7 +121,7 @@ LogicalResult StatisticsOp::verify() { std::multiplies()); auto axisStatsType = getAxisStats()->getShapedType(); - if (!axisStatsType.getElementType().isa()) { + if (!mlir::isa(axisStatsType.getElementType())) { return emitOpError("axisStats must have a floating point element type"); } if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 || diff --git a/tensorflow/compiler/mlir/quantization/common/ir/QuantOpsBase.td b/tensorflow/compiler/mlir/quantization/common/ir/QuantOpsBase.td index d891ed17ee1443..fb762f933d6f00 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/QuantOpsBase.td +++ b/tensorflow/compiler/mlir/quantization/common/ir/QuantOpsBase.td @@ -27,7 +27,6 @@ include "mlir/IR/OpBase.td" def Quant_Dialect : Dialect { let name = "quantization"; let cppNamespace = "::mlir::quant::ir"; - let usePropertiesForAttributes = 0; } #endif // QUANTIZATION_BASE \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc index 5a200241af00dd..c0509bb8243bfc 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc @@ -15,54 +15,66 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" +#include +#include +#include +#include +#include #include +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project -using namespace mlir; -using namespace mlir::quantfork; +namespace mlir::quantfork { -static bool isQuantizablePrimitiveType(Type inputType) { - return inputType.isa(); +static bool isQuantizablePrimitiveType(Type input_type) { + return isa(input_type); } ExpressedToQuantizedConverter ExpressedToQuantizedConverter::forInputType( - Type inputType) { - if (inputType.isa()) { - Type elementType = inputType.cast().getElementType(); - if (!isQuantizablePrimitiveType(elementType)) - return ExpressedToQuantizedConverter{inputType, nullptr}; - return ExpressedToQuantizedConverter{inputType, elementType}; + Type input_type) { + if (isa(input_type)) { + Type element_type = cast(input_type).getElementType(); + if (!isQuantizablePrimitiveType(element_type)) + return ExpressedToQuantizedConverter{input_type, nullptr}; + return ExpressedToQuantizedConverter{input_type, element_type}; } // Supported primitive type (which just is the expressed type). - if (isQuantizablePrimitiveType(inputType)) - return ExpressedToQuantizedConverter{inputType, inputType}; + if (isQuantizablePrimitiveType(input_type)) + return ExpressedToQuantizedConverter{input_type, input_type}; // Unsupported. 
- return ExpressedToQuantizedConverter{inputType, nullptr}; + return ExpressedToQuantizedConverter{input_type, nullptr}; } Type ExpressedToQuantizedConverter::convert( - quant::QuantizedType elementalType) const { - assert(expressedType && "convert() on unsupported conversion"); - if (auto tensorType = inputType.dyn_cast()) - return RankedTensorType::get(tensorType.getShape(), elementalType); - if (auto tensorType = inputType.dyn_cast()) - return UnrankedTensorType::get(elementalType); - if (auto vectorType = inputType.dyn_cast()) - return VectorType::get(vectorType.getShape(), elementalType); + quant::QuantizedType elemental_type) const { + assert(expressed_type && "convert() on unsupported conversion"); + if (auto tensor_type = dyn_cast(input_type)) + return RankedTensorType::get(tensor_type.getShape(), elemental_type); + if (auto tensor_type = dyn_cast(input_type)) + return UnrankedTensorType::get(elemental_type); + if (auto vector_type = dyn_cast(input_type)) + return VectorType::get(vector_type.getShape(), elemental_type); // If the expressed types match, just use the new elemental type. - if (elementalType.getExpressedType() == expressedType) return elementalType; + if (elemental_type.getExpressedType() == expressed_type) { + return elemental_type; + } // Unsupported. return nullptr; } ElementsAttr UniformQuantizedPerAxisValueConverter::convert( - Attribute realValue) { - if (auto attr = realValue.dyn_cast()) { + Attribute real_value) { + if (auto attr = dyn_cast(real_value)) { return convert(attr); } - // TODO: handles sparse elements attribute return nullptr; } @@ -71,26 +83,30 @@ DenseElementsAttr UniformQuantizedPerAxisValueConverter::convert( // Creates the converter for each chunk. Normally the size of the // quantization dim is 3, so we can cache all the converters. ShapedType type = attr.getType(); - size_t dimSize = type.getDimSize(quantizationDim); - if (dimSize != scales.size()) { + std::size_t dim_size = type.getDimSize(quantization_dim_); + if (dim_size != scales_.size()) { return {}; } SmallVector converters; - converters.reserve(dimSize); - for (int i = 0, e = dimSize; i != e; ++i) { + converters.reserve(dim_size); + for (int i = 0, e = dim_size; i != e; ++i) { converters.push_back(getPerChunkConverter(i)); } // Scan the elements of the dense elements attributes and quantize them by // using the right quantization parameters. 
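// Row-major flattening means each run of chunk_size consecutive elements
// shares one index along the quantized dimension, so the channel for flat
// index i is (i / chunk_size) % dim_size. Illustration with an assumed shape
// [2, 3, 4] quantized along dim 1: chunk_size = 4, dim_size = 3, and flat
// index 17 (coordinates (1, 1, 1)) maps to channel (17 / 4) % 3 = 1.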
- int64_t flattenIndex = 0; + int64_t flatten_index = 0; auto shape = type.getShape(); - int64_t chunkSize = - std::accumulate(std::next(shape.begin(), quantizationDim + 1), + int64_t chunk_size = + std::accumulate(std::next(shape.begin(), quantization_dim_ + 1), shape.end(), 1, std::multiplies()); - Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth); - return attr.mapValues(newElementType, [&](const APFloat &old) { - int chunkIndex = (flattenIndex++) / chunkSize; - return converters[chunkIndex % dimSize].quantizeFloatToInt(old); + Type new_element_type = + IntegerType::get(attr.getContext(), storage_bit_width_); + return attr.mapValues(new_element_type, [&](const APFloat &old) { + int chunk_index = flatten_index / chunk_size; + flatten_index++; + return converters[chunk_index % dim_size].quantizeFloatToInt(old); }); } + +} // namespace mlir::quantfork diff --git a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h index b6f65e455d0c09..c0c6c30e0d6e58 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h @@ -16,130 +16,139 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ +#include +#include +#include +#include +#include #include #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project - -namespace mlir { -namespace quantfork { - -/// Performs type conversion from an arbitrary input type to a type -/// that is expressed by a QuantizedType. -/// -/// This handles cases where the inputType is a supported primitive type -/// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported -/// elemental type. -/// -/// Since conversion often involves introspecting some attributes of the -/// input type in order to determine how to represent it, this is a two step -/// process. +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir::quantfork { + +// Performs type conversion from an arbitrary input type to a type +// that is expressed by a QuantizedType. +// +// This handles cases where the inputType is a supported primitive type +// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported +// elemental type. +// +// Since conversion often involves introspecting some attributes of the +// input type in order to determine how to represent it, this is a two step +// process. struct ExpressedToQuantizedConverter { - /// Creates a converter for the given input type. - static ExpressedToQuantizedConverter forInputType(Type inputType); + // Creates a converter for the given input type. + static ExpressedToQuantizedConverter forInputType(Type input_type); - /// Converts the inputType to be based on the given elemental type, - /// returning the new type (or nullptr and emit an error on failure). 
- Type convert(quant::QuantizedType elementalType) const; + // Converts the inputType to be based on the given elemental type, + // returning the new type (or nullptr and emit an error on failure). + Type convert(quant::QuantizedType elemental_type) const; - /// Whether the conversion is legal. - explicit operator bool() const { return (bool)expressedType; } + // Whether the conversion is legal. + explicit operator bool() const { return (bool)expressed_type; } - /// The input type that is being converted from. - /// This may be an elemental or composite type. - const Type inputType; + // The input type that is being converted from. + // This may be an elemental or composite type. + const Type input_type; - /// Supported, elemental expressed type (i.e. f32). - /// Will be nullptr if conversion is not supported. - const Type expressedType; + // Supported, elemental expressed type (i.e. f32). + // Will be nullptr if conversion is not supported. + const Type expressed_type; }; -/// Reference implementation of converting between real numbers and values -/// represented by a UniformQuantizedType. -/// Note that this is not expected to be speedy and may be superseded eventually -/// by a more optimal implementation. -/// Also, the interface assumes that quantization is done per-layer and will -/// need to be wider for various per-channel schemes. As such, this is a -/// placeholder. +// Reference implementation of converting between real numbers and values +// represented by a UniformQuantizedType. +// Note that this is not expected to be speedy and may be superseded eventually +// by a more optimal implementation. +// Also, the interface assumes that quantization is done per-layer and will +// need to be wider for various per-channel schemes. As such, this is a +// placeholder. 
class UniformQuantizedValueConverter { public: explicit UniformQuantizedValueConverter( - quant::UniformQuantizedType uniformType) + quant::UniformQuantizedType uniform_type) : UniformQuantizedValueConverter( - uniformType.getScale(), - static_cast(uniformType.getZeroPoint()), - static_cast(uniformType.getStorageTypeMin()), - static_cast(uniformType.getStorageTypeMax()), - uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) { - assert(uniformType.getExpressedType().isa()); - assert(uniformType.getStorageType().isSignlessInteger()); + uniform_type.getScale(), + static_cast(uniform_type.getZeroPoint()), + static_cast(uniform_type.getStorageTypeMin()), + static_cast(uniform_type.getStorageTypeMax()), + uniform_type.getStorageTypeIntegralWidth(), + uniform_type.isSigned()) { + assert(isa(uniform_type.getExpressedType())); + assert(uniform_type.getStorageType().isSignlessInteger()); } - UniformQuantizedValueConverter(double scale, double zeroPoint, - double clampMin, double clampMax, - uint32_t storageBitWidth, bool isSigned) - : scale(scale), - zeroPoint(zeroPoint), - clampMin(clampMin), - clampMax(clampMax), - scaleDouble(scale), - zeroPointDouble(zeroPoint), - clampMinDouble(clampMin), - clampMaxDouble(clampMax), - storageBitWidth(storageBitWidth), - isSigned(isSigned), - roundMode(APFloat::rmNearestTiesToAway) {} - - UniformQuantizedValueConverter(double scale, double zeroPoint, - const APFloat &clampMin, - const APFloat &clampMax, - uint32_t storageBitWidth, bool isSigned) - : scale(scale), - zeroPoint(zeroPoint), - clampMin(clampMin), - clampMax(clampMax), - scaleDouble(scale), - zeroPointDouble(zeroPoint), - clampMinDouble(clampMin.convertToDouble()), - clampMaxDouble(clampMax.convertToDouble()), - storageBitWidth(storageBitWidth), - isSigned(isSigned), - roundMode(APFloat::rmNearestTiesToAway) {} - - virtual APInt quantizeFloatToInt(APFloat expressedValue) const { + UniformQuantizedValueConverter(double scale, double zero_point, + double clamp_min, double clamp_max, + uint32_t storage_bit_width, bool is_signed) + : scale_(scale), + zero_point_(zero_point), + clamp_min_(clamp_min), + clamp_max_(clamp_max), + scale_double_(scale), + zero_point_double_(zero_point), + clamp_min_double_(clamp_min), + clamp_max_double_(clamp_max), + storage_bit_width_(storage_bit_width), + is_signed_(is_signed), + round_mode_(APFloat::rmNearestTiesToAway) {} + + UniformQuantizedValueConverter(double scale, double zero_point, + const APFloat& clamp_min, + const APFloat& clamp_max, + uint32_t storage_bit_width, bool is_signed) + : scale_(scale), + zero_point_(zero_point), + clamp_min_(clamp_min), + clamp_max_(clamp_max), + scale_double_(scale), + zero_point_double_(zero_point), + clamp_min_double_(clamp_min.convertToDouble()), + clamp_max_double_(clamp_max.convertToDouble()), + storage_bit_width_(storage_bit_width), + is_signed_(is_signed), + round_mode_(APFloat::rmNearestTiesToAway) {} + + virtual APInt quantizeFloatToInt(APFloat expressed_value) const { // This function is a performance critical code path in quantization // since it runs for each single float parameter value. // Specialize f32->u8/i8 case to optimize performance. 
- if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() && - storageBitWidth == 8 && - roundMode == llvm::APFloatBase::rmNearestTiesToAway) { - return quantizeF32ToInt8(expressedValue); + if (&expressed_value.getSemantics() == &APFloat::IEEEsingle() && + storage_bit_width_ == 8 && + round_mode_ == llvm::APFloatBase::rmNearestTiesToAway) { + return quantizeF32ToInt8(expressed_value); } bool lossy; - expressedValue.convert(scale.getSemantics(), roundMode, &lossy); - // fixedpoint = clamp(clampMin, clampMax, ( - // roundHalfToEven(expressed / scale) + zeroPoint)) - APFloat scaled = (expressedValue / scale); - scaled.roundToIntegral(roundMode); - scaled.add(zeroPoint, roundMode); - APFloat fixedpoint = llvm::minimum(scaled, clampMax); - fixedpoint = llvm::maximum(fixedpoint, clampMin); - - llvm::APSInt result(storageBitWidth, !isSigned); - fixedpoint.convertToInteger(result, roundMode, &lossy); + expressed_value.convert(scale_.getSemantics(), round_mode_, &lossy); + // fixed_point = clamp(clamp_min, clamp_max, ( + // roundHalfToEven(expressed / scale) + zero_point)) + APFloat scaled = (expressed_value / scale_); + scaled.roundToIntegral(round_mode_); + scaled.add(zero_point_, round_mode_); + APFloat fixed_point = llvm::minimum(scaled, clamp_max_); + fixed_point = llvm::maximum(fixed_point, clamp_min_); + + llvm::APSInt result(storage_bit_width_, !is_signed_); + fixed_point.convertToInteger(result, round_mode_, &lossy); return std::move(result); } - int64_t quantizeFloatToInt64(APFloat expressedValue) const { - APInt qValue = quantizeFloatToInt(std::move(expressedValue)); - return isSigned ? qValue.getSExtValue() : qValue.getZExtValue(); + int64_t quantizeFloatToInt64(APFloat expressed_value) const { + const APInt q_value = quantizeFloatToInt(std::move(expressed_value)); + return is_signed_ ? q_value.getSExtValue() : q_value.getZExtValue(); } virtual ~UniformQuantizedValueConverter() = default; @@ -147,94 +156,92 @@ class UniformQuantizedValueConverter { private: // An optimized implementation to quantize f32 to i8/u8 with C++ native // arithmetic. - virtual APInt quantizeF32ToInt8(APFloat expressedValue) const { - assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle()); - assert(storageBitWidth == 8); - assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway); + virtual APInt quantizeF32ToInt8(const APFloat& expressed_value) const { + assert(&expressed_value.getSemantics() == &APFloat::IEEEsingle()); + assert(storage_bit_width_ == 8); + assert(round_mode_ == llvm::APFloatBase::rmNearestTiesToAway); - const float realValue = expressedValue.convertToFloat(); + const float real_value = expressed_value.convertToFloat(); - const double scaled = realValue / scaleDouble + zeroPointDouble; + const double scaled = real_value / scale_double_ + zero_point_double_; // Round to nearest integer with halfway cases rounded away from zero. 
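// Worked example with illustrative parameters, not taken from the change:
// with scale = 0.5 and zero_point = 10, an input of 3.2f quantizes to
// round(3.2 / 0.5) + 10 = 6 + 10 = 16, which the i8 clamp range of
// [-128, 127] leaves unchanged.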
- const double scaledRounded = std::round(scaled); - const double clamped = - std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble); - - uint64_t signlessResult; - if (isSigned) { - int64_t clampedInt = static_cast(clamped); - memcpy(&signlessResult, &clampedInt, sizeof(clampedInt)); + const double scaled_rounded = std::round(scaled); + const double clamped = std::min(std::max(scaled_rounded, clamp_min_double_), + clamp_max_double_); + + uint64_t signless_result; + if (is_signed_) { + int64_t clamped_int = static_cast(clamped); + memcpy(&signless_result, &clamped_int, sizeof(clamped_int)); } else { - signlessResult = static_cast(clamped); + signless_result = static_cast(clamped); } - return APInt(storageBitWidth, signlessResult); + return APInt(storage_bit_width_, signless_result); } // Keep both APFloat and double versions of the quantization parameters // around since they will be used in generic and specialized arithmetic, // respectively. - const APFloat scale; - const APFloat zeroPoint; - const APFloat clampMin; - const APFloat clampMax; - - const double scaleDouble; - const double zeroPointDouble; - const double clampMinDouble; - const double clampMaxDouble; - - const uint32_t storageBitWidth; - const bool isSigned; - const llvm::APFloat::roundingMode roundMode; + const APFloat scale_; + const APFloat zero_point_; + const APFloat clamp_min_; + const APFloat clamp_max_; + + const double scale_double_; + const double zero_point_double_; + const double clamp_min_double_; + const double clamp_max_double_; + + const uint32_t storage_bit_width_; + const bool is_signed_; + const llvm::APFloat::roundingMode round_mode_; }; -/// An utility class to quantize an attribute by the per-axis quantization -/// parameters. The size of the quantization dim in the converted elements -/// attribute should matche the size of of scales/zeroPoints vectors in the -/// quantization parameters. +// A utility class to quantize an attribute by the per-axis quantization +// parameters. The size of the quantization dim in the converted elements +// attribute should match the size of the scales/zero_points vectors in the +// quantization parameters. class UniformQuantizedPerAxisValueConverter { public: explicit UniformQuantizedPerAxisValueConverter( - quant::UniformQuantizedPerAxisType uniformType) - : scales(uniformType.getScales()), - zeroPoints(uniformType.getZeroPoints()), - clampMin(static_cast(uniformType.getStorageTypeMin())), - clampMax(static_cast(uniformType.getStorageTypeMax())), - storageBitWidth(uniformType.getStorageTypeIntegralWidth()), - isSigned(uniformType.isSigned()), - quantizationDim(uniformType.getQuantizedDimension()) { - assert(uniformType.getExpressedType().isa()); - assert(uniformType.getStorageType().isSignlessInteger()); - assert(scales.size() == zeroPoints.size()); + quant::UniformQuantizedPerAxisType uniform_type) + : scales_(uniform_type.getScales()), + zero_points_(uniform_type.getZeroPoints()), + clamp_min_(static_cast(uniform_type.getStorageTypeMin())), + clamp_max_(static_cast(uniform_type.getStorageTypeMax())), + storage_bit_width_(uniform_type.getStorageTypeIntegralWidth()), + is_signed_(uniform_type.isSigned()), + quantization_dim_(uniform_type.getQuantizedDimension()) { + assert(isa(uniform_type.getExpressedType())); + assert(uniform_type.getStorageType().isSignlessInteger()); + assert(scales_.size() == zero_points_.size()); } - /// Quantize an Attribute by the quantization parameters.
Return nullptr if - /// the conversion fails or the input array isn't an ElementsAttr. - ElementsAttr convert(Attribute realValue); + // Quantize an Attribute by the quantization parameters. Return nullptr if + // the conversion fails or the input array isn't an ElementsAttr. + ElementsAttr convert(Attribute real_value); private: - /// Quantize an DenseFPElementsAttr by the quantization parameters. + // Quantize a DenseFPElementsAttr by the quantization parameters. DenseElementsAttr convert(DenseFPElementsAttr attr); - /// Get a uniform converter for the index-th chunk along the quantizationDim. - /// All the elements in this chunk is quantized by the returned converter. + // Get a uniform converter for the index-th chunk along the quantization dim. + // All the elements in this chunk are quantized by the returned converter. UniformQuantizedValueConverter getPerChunkConverter(int index) const { - UniformQuantizedValueConverter converter(scales[index], zeroPoints[index], - clampMin, clampMax, - storageBitWidth, isSigned); - return converter; + return UniformQuantizedValueConverter(scales_[index], zero_points_[index], + clamp_min_, clamp_max_, + storage_bit_width_, is_signed_); } - const ArrayRef scales; - const ArrayRef zeroPoints; - const APFloat clampMin; - const APFloat clampMax; - const uint32_t storageBitWidth; - const bool isSigned; - int32_t quantizationDim; + const ArrayRef scales_; + const ArrayRef zero_points_; + const APFloat clamp_min_; + const APFloat clamp_max_; + const uint32_t storage_bit_width_; + const bool is_signed_; + int32_t quantization_dim_; }; -} // namespace quantfork -} // namespace mlir +} // namespace mlir::quantfork #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc index 050bf45d7b5a46..bf894948d4cec8 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -136,7 +138,7 @@ ValueRange CreateTFXlaCallModuleOp(OpBuilder& builder, const Location location, SmallVector shape_attrs; for (const Type result_type : output_types) { shape_attrs.push_back( - tf_type::ShapeAttr::get(ctx, result_type.cast())); + tf_type::ShapeAttr::get(ctx, mlir::cast(result_type))); } auto empty_array_attr = ArrayAttr::get(ctx, {}); auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)}); @@ -266,9 +268,9 @@ LogicalResult SetAttributeMap(MLIRContext& context, const NamedAttribute& attribute = attributes[idx]; // Skip the following steps if the attribute value is `NullAttribute`.
if (const auto string_attr = - attribute.getValue().dyn_cast_or_null(); + mlir::dyn_cast_or_null(attribute.getValue()); string_attr != nullptr && - string_attr.getValue().equals(kNullAttributeValue)) { + string_attr.getValue() == kNullAttributeValue) { continue; } @@ -479,10 +481,9 @@ bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr) { rhs_out_idx_start >= batch_dim_size; } -absl::StatusOr GetQuantizationMethod( - TF::XlaCallModuleOp xla_call_module_op) { +absl::StatusOr GetQuantizationMethod(absl::Nonnull op) { const auto quantization_method_attr = - xla_call_module_op->getAttrOfType(kQuantizationMethodAttr); + op->getAttrOfType(kQuantizationMethodAttr); if (!quantization_method_attr) { return absl::InvalidArgumentError(absl::StrCat( "Attribute ", kQuantizationMethodAttr.str(), " is not found.")); @@ -498,15 +499,40 @@ absl::StatusOr GetQuantizationMethod( return quantization_method; } -Method GetQuantizationMethodOrDefault(TF::XlaCallModuleOp xla_call_module_op) { - absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); +Method GetQuantizationMethodOrDefault(absl::Nonnull op) { + absl::StatusOr method = GetQuantizationMethod(op); if (method.status().code() == absl::StatusCode::kInternal) { // This indicates that the `Method` protobuf string is corrupt, but this // function ignores it and returns the default instance. - xla_call_module_op->emitError(absl::StrCat( - "Failed to get quantization method: ", method.status().ToString())); + op->emitError(absl::StrCat("Failed to get quantization method: ", + method.status().ToString())); } return method.ok() ? *method : Method::default_instance(); } +bool HasWeightOnlyPtqMethod(TF::XlaCallModuleOp xla_call_module_op) { + Method method = GetQuantizationMethodOrDefault(xla_call_module_op); + return method.has_weight_only_ptq(); +} + +bool IsWeightOnlyQuantizableOp(const Operation& op) { + if (auto call_op = dyn_cast(op)) { + StringRef entry_function_name = GetEntryFunctionName(call_op); + absl::StatusOr quantization_method = GetQuantizationMethod(call_op); + return ContainsConvOrDot(entry_function_name) && quantization_method.ok() && + quantization_method->has_weight_only_ptq(); + } + return false; +} + +SmallVector GetSortedFunctions(ModuleOp module_op) { + auto iterator_range = module_op.getOps(); + SmallVector func_ops(iterator_range.begin(), + iterator_range.end()); + absl::c_sort(func_ops, [](func::FuncOp op1, func::FuncOp op2) { + return op1.getName() < op2.getName(); + }); + return func_ops; +} + } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h index bfef9a13df1a01..2d22816e725a48 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h @@ -15,11 +15,15 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_LIFT_AS_FUNCTION_CALL_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_LIFT_AS_FUNCTION_CALL_H_ +#include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "llvm/ADT/SmallVector.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -43,10 +47,6 @@ constexpr StringRef kCompositeFuncPrefix = "composite_"; inline constexpr StringRef kOriginalStablehloEntryFunctionAttrName = "_original_entry_function"; -// Name of the string attribute attached to `XlaCallModuleOp`, which is the -// textproto representation of `Method`. -inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; - // FunctionCallOpType to be generated as the function call operator when // function lifting will happen. enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 }; @@ -62,19 +62,20 @@ bool IsInStableHloOpRegion(Operation* op); // Checks if a given einsum op is supported for XlaDotV2 quantization. bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr); -// Gets the quantization method from the given `XlaCallModuleOp`. It is -// retrieved from the `kQuantizationMethodAttr` string attribute. Returns +// Gets the quantization method from `op`. It is retrieved from the +// `kQuantizationMethodAttr` string attribute. Returns // `absl::InvalidArgumentError` when the attribute doesn't exist. Returns // `absl::InternalError` when parsing the attribute to `Method` failed. +// `op` must be non-null. absl::StatusOr<::stablehlo::quantization::Method> GetQuantizationMethod( - TF::XlaCallModuleOp xla_call_module_op); + absl::Nonnull op); -// Gets the quantization method from the given `XlaCallModuleOp`. It is -// retrieved from the `kQuantizationMethodAttr` string attribute. Returns a -// default instance of `Method` iff the attribute doesn't exist or the attribute -// contains an invalid textproto for `Method`. +// Gets the quantization method from `op`. It is retrieved from the +// `kQuantizationMethodAttr` string attribute. Returns a default instance of +// `Method` iff the attribute doesn't exist or the attribute contains an invalid +// textproto for `Method`. `op` must be non-null. ::stablehlo::quantization::Method GetQuantizationMethodOrDefault( - TF::XlaCallModuleOp xla_call_module_op); + absl::Nonnull op); // Creates a function to wrap the section between arguments and results. // The generated function call op type will be decided by the given call_op_type @@ -99,6 +100,17 @@ SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, // Used to attach bias to einsum argument list. SmallVector AppendToVector(ArrayRef arguments, Value append); +// Checks if the `Method` attached to the given `tf.XlaCallModule` op has +// `WeightOnlyPtq`. +bool HasWeightOnlyPtqMethod(TF::XlaCallModuleOp xla_call_module_op); + +// Checks if an op is a `tf.XlaCallModule` op, contains 'conv' or 'dot_general' +// in its name, and has `Method` with `WeightOnlyPtq`.
+bool IsWeightOnlyQuantizableOp(const Operation& op); + +// Lists the functions in a ModuleOp sorted by their names. +SmallVector GetSortedFunctions(ModuleOp module_op); + } // namespace mlir::quant #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_LIFT_AS_FUNCTION_CALL_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc index 5e5e103ba72018..4a40a70700f9b2 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc @@ -47,6 +47,8 @@ namespace { using ::stablehlo::quantization::Method; using ::testing::HasSubstr; using ::testing::NotNull; +using ::testing::SizeIs; +using ::testing::StrEq; using ::tsl::protobuf::util::MessageDifferencer; using ::tsl::testing::IsOk; using ::tsl::testing::StatusIs; @@ -118,10 +120,11 @@ TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { FindOperationOfType(entry_func); EXPECT_TRUE(isa(lifted_op)); - EXPECT_EQ(lifted_op->getAttr("_original_entry_function").cast(), - "composite_dot_general_fn_1"); EXPECT_EQ( - lifted_dot_general_op->getAttr("precision_config").cast(), + mlir::cast(lifted_op->getAttr("_original_entry_function")), + "composite_dot_general_fn_1"); + EXPECT_EQ( + mlir::cast(lifted_dot_general_op->getAttr("precision_config")), builder_.getArrayAttr(SmallVector( 1, mlir::stablehlo::PrecisionAttr::get( ctx_.get(), mlir::stablehlo::Precision::DEFAULT)))); @@ -144,8 +147,9 @@ TEST_F(LiftAsFunctionCallTest, FunctionNoAttrLiftedAsXlaCallModuleOp) { "composite_dot_general_fn", operands, results)[0] .getDefiningOp(); EXPECT_TRUE(isa(lifted_op)); - EXPECT_EQ(lifted_op->getAttr("_original_entry_function").cast(), - "composite_dot_general_fn_1"); + EXPECT_EQ( + mlir::cast(lifted_op->getAttr("_original_entry_function")), + "composite_dot_general_fn_1"); } TEST_F(LiftAsFunctionCallTest, EinsumSupportedForXlaDotV2Succeeds) { @@ -351,6 +355,179 @@ TEST_F( const Method method = GetQuantizationMethodOrDefault(*xla_call_module_op); EXPECT_TRUE(MessageDifferencer::Equals(method, Method::default_instance())); } +constexpr absl::string_view kModuleDotWeightOnlyPtq = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "weight_only_ptq { }"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_dot_general_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x2xf32>) -> tensor + return %0 : tensor + } + } +)mlir"; + +TEST_F(LiftAsFunctionCallTest, HasWeightOnlyPtqMethodExists) { + OwningOpRef module_op = + ParseModuleOpString(kModuleDotWeightOnlyPtq); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = 
*main_fn.getOps().begin(); + EXPECT_TRUE(HasWeightOnlyPtqMethod(call_op)); +} + +TEST_F(LiftAsFunctionCallTest, HasWeightOnlyPtqMethodDifferentMethod) { + const absl::string_view kModuleDotNoQuantization = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "no_quantization { }"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_dot_general_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x2xf32>) -> tensor + return %0 : tensor + } + } + )mlir"; + OwningOpRef module_op = + ParseModuleOpString(kModuleDotNoQuantization); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_FALSE(HasWeightOnlyPtqMethod(call_op)); +} + +TEST_F(LiftAsFunctionCallTest, HasWeightOnlyPtqMethodNoMethod) { + const absl::string_view kModuleXlaCallModule = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + return %arg0 : tensor + } + } + )mlir"; + OwningOpRef module_op = ParseModuleOpString(kModuleXlaCallModule); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_FALSE(HasWeightOnlyPtqMethod(call_op)); +} + +TEST_F(LiftAsFunctionCallTest, IsWeightOnlyQuantizableOpDot) { + OwningOpRef module_op = + ParseModuleOpString(kModuleDotWeightOnlyPtq); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_TRUE(IsWeightOnlyQuantizableOp(*call_op)); +} + +TEST_F(LiftAsFunctionCallTest, IsWeightOnlyQuantizableOpNotTfXlaCallModuleOp) { + const absl::string_view kModulePartitionedCallDot = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.PartitionedCall"(%arg0, %1, %0) 
{_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_dot_general_fn_1, _quantization_method = "weight_only_ptq { }"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_dot_general_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x2xf32>) -> tensor + return %0 : tensor + } + } + )mlir"; + OwningOpRef module_op = + ParseModuleOpString(kModulePartitionedCallDot); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_FALSE(IsWeightOnlyQuantizableOp(*call_op)); +} + +TEST_F(LiftAsFunctionCallTest, IsWeightOnlyQuantizableOpNoConvNoDot) { + constexpr absl::string_view kModuleXlaCallModule = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "weight_only_ptq { }"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + return %arg0 : tensor + } + } + )mlir"; + OwningOpRef module_op = ParseModuleOpString(kModuleXlaCallModule); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_FALSE(IsWeightOnlyQuantizableOp(*call_op)); +} + +TEST_F(LiftAsFunctionCallTest, GetSortedFunctions) { + constexpr absl::string_view kModuleXlaCallModule = R"mlir( + module { + func.func @conv_3_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } + + func.func @conv_1_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, 
feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } + + func.func @conv_2_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } + } + )mlir"; + OwningOpRef module_op = ParseModuleOpString(kModuleXlaCallModule); + ASSERT_TRUE(module_op); + + SmallVector funcs = GetSortedFunctions(*module_op); + ASSERT_THAT(funcs, SizeIs(3)); + EXPECT_THAT(funcs[0].getSymName(), StrEq("conv_1_fn")); + EXPECT_THAT(funcs[1].getSymName(), StrEq("conv_2_fn")); + EXPECT_THAT(funcs[2].getSymName(), StrEq("conv_3_fn")); +} } // namespace } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc index 216a4a2b3d58e9..7645177160fc62 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc @@ -91,10 +91,11 @@ bool HasPerAxisQuantizedOperand(Operation* op) { for (int i = 0; i < op->getNumOperands(); ++i) { if (auto dq_op = dyn_cast_or_null( op->getOperand(i).getDefiningOp())) { - auto type = dq_op.getArg().getType().cast().getElementType(); + auto type = + mlir::cast(dq_op.getArg().getType()).getElementType(); if (auto per_axis_qtype = - QuantizedType::getQuantizedElementType(type) - .dyn_cast_or_null()) { + mlir::dyn_cast_or_null( + QuantizedType::getQuantizedElementType(type))) { return true; } } @@ -179,7 +180,7 @@ bool QuantizationDriver::SetConstantResultParams(Operation* op) { /*num_bits=*/8, is_signed_, /*narrow_range=*/is_weight, legacy_float_scale_); } - if (const auto quant_type = final_type.dyn_cast_or_null(); + if (const auto quant_type = mlir::dyn_cast_or_null(final_type); quant_type != nullptr) { return SetResultParams(op, /*result_index=*/0, quant_type); } @@ -225,7 +226,7 @@ QuantizedType QuantizationDriver::GetBiasParams( if (bias_op != nullptr) { Type bias_type = bias_op->getResult(0).getType(); if (bias_type != builder_.getNoneType()) { - const int bias_rank = bias_type.dyn_cast().getRank(); + const int bias_rank = mlir::dyn_cast(bias_type).getRank(); adjusted_quant_dim = bias_rank > 1 ? bias_rank - 1 : 0; } } @@ -489,12 +490,12 @@ QuantizedType QuantizationDriver::GetQuantParamsForSameScaleConstraint( void QuantizationDriver::PreprocessConstantOps() { fn_.walk([&](arith::ConstantOp cst) { // Non-float tensors are neither weights nor require quantization. - const auto type = cst.getType().dyn_cast(); - if (!type || !type.getElementType().isa()) return; + const auto type = mlir::dyn_cast(cst.getType()); + if (!type || !mlir::isa(type.getElementType())) return; // Skip if the value is NaN or INF. // Otherwise the illegal scale/zp will be calculated. 
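// The scale and zero point are derived from the constant's value range, so a
// non-finite entry would propagate: an inf value yields an infinite range and
// a NaN value defeats the min/max comparisons, producing unusable
// quantization parameters. Such constants are simply left unquantized.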
- auto float_attr = cst.getValueAttr().dyn_cast(); + auto float_attr = mlir::dyn_cast(cst.getValueAttr()); if (float_attr && (float_attr.getValues().empty() || !float_attr.getValues()[0].isFinite())) { return; @@ -620,7 +621,7 @@ bool QuantizationDriver::ShouldCheckBiasScale( auto affine_op = dyn_cast(op); auto bias_op = op->getOperand(bias_index).getDefiningOp(); if (!affine_op || !bias_op || input_indices.size() != 2) return false; - if (!bias_op.getValue().isa()) return false; + if (!mlir::isa(bias_op.getValue())) return false; filter_index = affine_op.GetAffineOperandIndex(); if (!op->getOperand(filter_index).getDefiningOp()) { return false; @@ -658,12 +659,12 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( QuantState filter_state = GetOperandQuantState(op, filter_index); auto bias_op = op->getOperand(bias_index).getDefiningOp(); const double input_scale = - input_state.params.cast().getScale(); + mlir::cast(input_state.params).getScale(); - auto bias_values = bias_op.getValue().cast(); + auto bias_values = mlir::cast(bias_op.getValue()); // Restrict maximum absolute value of bias within INT_MAX / 2, to make some // room for accumulator. - if (auto bias_quantized_type = params.dyn_cast(); + if (auto bias_quantized_type = mlir::dyn_cast(params); bias_quantized_type != nullptr) { double bias_half_range = 0.0f; for (auto bias : bias_values.getValues()) { @@ -691,7 +692,7 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( } const auto filter_quantized_type = - filter_state.params.cast(); + mlir::cast(filter_state.params); changed |= SetOperandParams( op, filter_index, UniformQuantizedType::getChecked( @@ -703,10 +704,10 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( filter_quantized_type.getStorageTypeMax()), /*override=*/true); } else if (auto bias_quantized_type = - params.dyn_cast(); + mlir::dyn_cast(params); bias_quantized_type != nullptr) { const auto filter_quantized_type = - filter_state.params.cast(); + mlir::cast(filter_state.params); std::vector new_bias_scales = bias_quantized_type.getScales().vec(); std::vector new_filter_scales = filter_quantized_type.getScales().vec(); @@ -822,21 +823,22 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { // Use the final state to set all the operands' parameters. for (int i = 0; i < op->getNumOperands(); ++i) { - if (auto type = op->getOperand(i).getType().dyn_cast()) { + if (auto type = + mlir::dyn_cast(op->getOperand(i).getType())) { // Without this check, it will accidentally propagate the quantization // information by the shared non-float tensors. - if (type.getElementType().isa()) + if (mlir::isa(type.getElementType())) changed |= SetOperandParams(op, i, params); } } // Use the final state to set all the results' parameters. for (int i = 0; i < op->getNumResults(); ++i) - if (auto type = op->getResult(i).getType().dyn_cast(); + if (auto type = mlir::dyn_cast(op->getResult(i).getType()); type != nullptr) { // Without this check, it will accidentally propagate the quantization // information by the shared non-float-tensors. 
- if (type.getElementType().isa()) + if (mlir::isa(type.getElementType())) changed |= SetResultParams(op, i, params); } } diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc index cc82c09894b46b..f017054cbe7044 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc @@ -159,10 +159,9 @@ TEST_F(ApplyQuantizationParamsPropagationTest, FinalizeInsertsQDQOps) { ASSERT_NE(filter_qcast_op, nullptr); EXPECT_TRUE(isa(filter_qcast_op)); EXPECT_TRUE(isa(filter_dcast_op)); - EXPECT_TRUE(isa(filter_qcast_op->getResult(0) - .getType() - .cast() - .getElementType())); + EXPECT_TRUE(isa( + mlir::cast(filter_qcast_op->getResult(0).getType()) + .getElementType())); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc index f6c561be98d49b..8e5496106c5279 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc @@ -29,7 +29,6 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project @@ -125,14 +124,13 @@ QuantizedType ResetMinMaxFromNumBits(const QuantizedType type, const auto& recalculate_zero_point = [&](int64_t zero_point) -> int64_t { return qmax - std::round((storage_type_max - zero_point) / rate); }; - if (auto q_type = type.dyn_cast()) { + if (auto q_type = dyn_cast(type)) { const double scale = recalculate_scale(q_type.getScale()); const double zero_point = recalculate_zero_point(q_type.getZeroPoint()); return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), scale, zero_point, qmin, qmax); - } else if (auto q_type = - type.dyn_cast()) { + } else if (auto q_type = dyn_cast(type)) { const int size = q_type.getScales().size(); SmallVector scales(size); SmallVector zero_points(size); @@ -155,7 +153,7 @@ quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast( const ArrayRef shape, const quant::UniformQuantizedPerAxisType qtype, const Type target, const int quant_dim) { - const auto shaped = target.dyn_cast(); + const auto shaped = dyn_cast(target); if (!shaped) return {}; const ArrayRef new_shape = shaped.getShape(); @@ -236,52 +234,54 @@ Type GetQuantizedType(Builder builder, const Type input_type, SmallVector effective_mins, effective_maxs; ExpandVerySmallRange(min, max, effective_mins, effective_maxs); - quant::QuantizedType quantizedEleType; + quant::QuantizedType quantized_element_type; if (min.size() == 1 && max.size() == 1 && quant_dim == -1) { - quantizedEleType = quantfork::fakeQuantAttrsToType( + quantized_element_type = quantfork::fakeQuantAttrsToType( builder.getUnknownLoc(), storage_type_width, effective_mins[0], - effective_maxs[0], narrow_range, converter.expressedType, is_signed); + effective_maxs[0], narrow_range, converter.expressed_type, is_signed); if (legacy_float_scale) { - quantizedEleType = - DownCastScale(quantizedEleType, effective_mins[0], effective_maxs[0], - 
builder.getUnknownLoc()); + quantized_element_type = + DownCastScale(quantized_element_type, effective_mins[0], + effective_maxs[0], builder.getUnknownLoc()); } } else if (min.size() == max.size()) { - auto shape = input_type.dyn_cast(); + auto shape = dyn_cast(input_type); if (!shape || shape.getRank() <= quant_dim || static_cast(min.size()) != shape.getDimSize(quant_dim)) { return {}; } // The quantization dim is set to the last dimension. - quantizedEleType = quantfork::fakeQuantAttrsToType( + quantized_element_type = quantfork::fakeQuantAttrsToType( builder.getUnknownLoc(), storage_type_width, quant_dim, effective_mins, - effective_maxs, narrow_range, converter.expressedType, is_signed); + effective_maxs, narrow_range, converter.expressed_type, is_signed); if (legacy_float_scale) { - quantizedEleType = DownCastScale(quantizedEleType, effective_mins, - effective_maxs, builder.getUnknownLoc()); + quantized_element_type = + DownCastScale(quantized_element_type, effective_mins, effective_maxs, + builder.getUnknownLoc()); } } - if (!quantizedEleType) return {}; + if (!quantized_element_type) return {}; // Use fake quant configured bit-widths (only supported for // 1 < num_bits < 8 bits) instead of using 8-bit defaults. if (use_fake_quant_num_bits && storage_type_width > 1 && storage_type_width < 8 && - quantizedEleType.getStorageTypeMax() > + quantized_element_type.getStorageTypeMax() > QType::getDefaultMinimumForInteger(is_signed, storage_type_width)) { const auto resetEleType = ResetMinMaxFromNumBits( - quantizedEleType, storage_type_width, narrow_range, is_signed); + quantized_element_type, storage_type_width, narrow_range, is_signed); return converter.convert(resetEleType); } - return converter.convert(quantizedEleType); + return converter.convert(quantized_element_type); } // TODO(fengliuai): promote this utility method to mlir QuantOps. TypeAttr RescaleQuantizedType(const Type input, const Attribute factor) { - const auto factor_values = factor.dyn_cast_or_null(); + const auto factor_values = dyn_cast_or_null(factor); if (!factor_values) return {}; - const auto ele_type = quant::QuantizedType::getQuantizedElementType(input); - if (!ele_type) return {}; - if (auto qtype = ele_type.dyn_cast()) { + const auto element_type = + quant::QuantizedType::getQuantizedElementType(input); + if (!element_type) return {}; + if (auto qtype = dyn_cast(element_type)) { const ArrayRef scales = qtype.getScales(); // Broadcasting hasn't been implemented yet. 
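
// ResetMinMaxFromNumBits (above) remaps an already-quantized type onto a
// narrower storage range by a constant rate, keeping the represented real
// interval fixed: the scale grows by `rate` and the zero point moves so the
// top of the range still maps to the same real value. A hedged sketch of that
// arithmetic with illustrative names; the zero-point formula mirrors the
// lambda shown in the hunk, the scale formula is the conventional companion.
#include <cmath>
#include <cstdint>
#include <utility>

std::pair<double, int64_t> RescaleToNumBits(
    const double scale, const int64_t zero_point,
    const int64_t storage_min, const int64_t storage_max,  // current range
    const int64_t qmin, const int64_t qmax) {              // target range
  const double rate =
      static_cast<double>(storage_max - storage_min) / (qmax - qmin);
  const double new_scale = scale * rate;
  const int64_t new_zero_point =
      qmax -
      static_cast<int64_t>(std::round((storage_max - zero_point) / rate));
  return {new_scale, new_zero_point};
}
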
      if (static_cast<int64_t>(scales.size()) != factor_values.getNumElements())
@@ -315,8 +315,8 @@ TypeAttr GetQuantizedTypeAttr(const Builder builder, const Type input_type,
                               const bool legacy_float_scale,
                               const bool use_fake_quant_num_bits) {
   SmallVector<double, 4> min_value, max_value;
-  const auto mins = min.dyn_cast<DenseFPElementsAttr>();
-  const auto maxs = max.dyn_cast<DenseFPElementsAttr>();
+  const auto mins = dyn_cast<DenseFPElementsAttr>(min);
+  const auto maxs = dyn_cast<DenseFPElementsAttr>(max);
   if (mins && maxs) {
     min_value.reserve(mins.getNumElements());
     max_value.reserve(maxs.getNumElements());
@@ -327,8 +327,8 @@ TypeAttr GetQuantizedTypeAttr(const Builder builder, const Type input_type,
       max_value.push_back(FloatAttr::getValueAsDouble(*it));
     }
   } else {
-    const auto fmin = min.dyn_cast<FloatAttr>();
-    const auto fmax = max.dyn_cast<FloatAttr>();
+    const auto fmin = dyn_cast<FloatAttr>(min);
+    const auto fmax = dyn_cast<FloatAttr>(max);
     if (fmin && fmax) {
       min_value.push_back(fmin.getValueAsDouble());
       max_value.push_back(fmax.getValueAsDouble());
@@ -348,14 +348,14 @@ TypeAttr CastQuantizedTypeAttrFromExpressedType(const Builder builder,
                                                 const TypeAttr source,
                                                 const Type target,
                                                 const int axis) {
-  const auto source_type = source.getValue().dyn_cast_or_null<ShapedType>();
+  const auto source_type = dyn_cast_or_null<ShapedType>(source.getValue());
   if (!source_type) return {};
   const auto src_ele_type = source_type.getElementType();
-  auto qtype = src_ele_type.dyn_cast<quant::QuantizedType>();
+  auto qtype = dyn_cast<quant::QuantizedType>(src_ele_type);

   // Reset the quantization dimensions if it is per-axis.
   if (const auto per_axis =
-          qtype.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()) {
+          dyn_cast_or_null<quant::UniformQuantizedPerAxisType>(qtype)) {
     // For pass-through ops, we don't know which dimension will become the new
     // quantization dimension. It is only safe to reset the per-axis quantized
     // type when the new quantization dimension can be inferred.
@@ -396,7 +396,9 @@ void ExtractMinMaxFromAttr(const DenseFPElementsAttr values, const int dim_size,
     }
   } else {
     int64_t flatten_index = 0;
-    for (auto it = values.begin(); it != values.end(); ++it, ++flatten_index) {
+    auto begin = values.begin();
+    auto end = values.end();
+    for (auto it = begin; it != end; ++it, ++flatten_index) {
       const double ele_value = FloatAttr::getValueAsDouble(*it);
       const int slice_index = flatten_index / slice_size;
       const int channel_index = slice_index % dim_size;
@@ -427,7 +429,7 @@ Type GetUniformQuantizedTypeForWeight(
   SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
   SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
-  const auto fp = attr.dyn_cast<DenseFPElementsAttr>();
+  const auto fp = dyn_cast<DenseFPElementsAttr>(attr);
   if (!fp) return {};

   // Computes the effective min/max values of the attribute values.
@@ -438,7 +440,7 @@ Type GetUniformQuantizedTypeForWeight(
       GetQuantizedType(builder, attr.getType(), mins[0], maxs[0],
                        /*quant_dim=*/-1, num_bits, narrow_range, is_signed,
                        legacy_float_scale, use_fake_quant_num_bits);
-  if (const auto ele_type = type.dyn_cast_or_null<TensorType>())
+  if (const auto ele_type = dyn_cast_or_null<TensorType>(type))
     return ele_type.getElementType();

   return {};
@@ -449,7 +451,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight(
     const unsigned num_bits, const bool is_signed, const bool narrow_range,
     const bool legacy_float_scale, const bool use_fake_quant_num_bits) {
   const Builder builder(attr.getContext());
-  const auto shape = attr.getType().cast<ShapedType>().getShape();
+  const auto shape = cast<ShapedType>(attr.getType()).getShape();
   if (static_cast<int>(shape.size()) <= quant_dim) return {};

   // `symmetric` can only be used when it is `signed` and `narrow_range`.
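
// ExtractMinMaxFromAttr (above) reduces a flattened constant into per-channel
// minima and maxima: with `slice_size` contiguous elements per slice, element
// i belongs to channel (i / slice_size) % dim_size. A small sketch of that
// reduction over a plain buffer, assuming the same layout:
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

void PerChannelMinMax(const std::vector<double>& values, const int dim_size,
                      const int slice_size, std::vector<double>& mins,
                      std::vector<double>& maxs) {
  mins.assign(dim_size, std::numeric_limits<double>::max());
  maxs.assign(dim_size, std::numeric_limits<double>::lowest());
  for (int64_t i = 0; i < static_cast<int64_t>(values.size()); ++i) {
    const int channel = static_cast<int>((i / slice_size) % dim_size);
    mins[channel] = std::min(mins[channel], values[i]);
    maxs[channel] = std::max(maxs[channel], values[i]);
  }
}
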
if (symmetric && (!is_signed || !narrow_range)) return {}; @@ -460,7 +462,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight( std::multiplies()); SmallVector mins(dim_size, std::numeric_limits::max()); SmallVector maxs(dim_size, std::numeric_limits::min()); - const auto fp = attr.dyn_cast(); + const auto fp = dyn_cast(attr); if (!fp) return {}; // Computes the effective min/max values of the attribute values. @@ -469,7 +471,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight( const auto type = GetQuantizedType( builder, attr.getType(), mins, maxs, quant_dim, num_bits, narrow_range, is_signed, legacy_float_scale, use_fake_quant_num_bits); - if (auto ele_type = type.dyn_cast_or_null()) + if (auto ele_type = dyn_cast_or_null(type)) return ele_type.getElementType(); return {}; @@ -495,28 +497,28 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( expressed_type = op_type.getExpressedType(); if (const auto type = - op_type.dyn_cast()) { + dyn_cast(op_type)) { if (axis_size != 1 && axis_size != type.getScales().size()) return {}; if (quant_dim != -1 && quant_dim != type.getQuantizedDimension()) return {}; axis_size = type.getScales().size(); quant_dim = type.getQuantizedDimension(); - } else if (!op_type.isa()) { + } else if (!isa(op_type)) { return {}; } } // The scale from the UniformQuantizedTypes is broadcasted if there are // UniformQuantizedPerAxisTypes. - llvm::SmallVector scales(axis_size, 1.0); + SmallVector scales(axis_size, 1.0); for (const auto op_type : op_types) { if (const auto type = - op_type.dyn_cast()) { + dyn_cast(op_type)) { for (const auto& index_scale : llvm::enumerate(type.getScales())) { scales[index_scale.index()] *= index_scale.value(); } } else if (const auto type = - op_type.dyn_cast()) { + dyn_cast(op_type)) { for (int index = 0; index < axis_size; ++index) { scales[index] *= type.getScale(); } @@ -541,7 +543,7 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( /*flags=*/true, storage_type, expressed_type, scales[0], /*zeroPoint=*/0, storage_type_min, storage_type_max); } else { - llvm::SmallVector zero_points(axis_size, 0); + SmallVector zero_points(axis_size, 0); // If the bias is a 1-D tensor, set the `quantizedDimension` to 0. // If the bias rank is larger than 1 because it was already broadcasted // to match the output shape, use the last index. 
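
// GetUniformQuantizedTypeForBias (above) combines the quantization parameters
// of the non-bias operands by multiplying their scales element-wise, with a
// per-tensor (scalar) scale broadcast across the per-axis ones. A sketch of
// that accumulation, where each operand contributes either one scale or
// `axis_size` scales; names are illustrative:
#include <vector>

std::vector<double> AccumulateBiasScales(
    const std::vector<std::vector<double>>& operand_scales,
    const int axis_size) {
  std::vector<double> scales(axis_size, 1.0);
  for (const std::vector<double>& op_scales : operand_scales) {
    for (int i = 0; i < axis_size; ++i) {
      // A single-element vector models a per-tensor scale being broadcast.
      scales[i] *= op_scales.size() == 1 ? op_scales[0] : op_scales[i];
    }
  }
  return scales;
}
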
@@ -555,30 +557,28 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( ElementsAttr QuantizeLegacy(const Attribute real_value, const Type tensor_type) { - if (!real_value.isa() || + if (!isa(real_value) || !quant::QuantizedType::getQuantizedElementType(tensor_type)) { return {}; } - const auto real_values_attr = real_value.cast(); + const auto real_values_attr = cast(real_value); auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type); std::vector real_values; - llvm::SmallVector quantized_attr; + SmallVector quantized_attr; real_values.reserve(real_values_attr.getNumElements()); quantized_attr.reserve(real_values_attr.getNumElements()); std::transform(real_values_attr.begin(), real_values_attr.end(), std::back_inserter(real_values), [&](APFloat value) -> float { return value.convertToFloat(); }); - const ShapedType new_dense_type = - q_type.castExpressedToStorageType(real_values_attr.getType()) - .dyn_cast_or_null(); - const int width = - q_type.getStorageType().dyn_cast().getWidth(); + const ShapedType new_dense_type = dyn_cast_or_null( + q_type.castExpressedToStorageType(real_values_attr.getType())); + const int width = dyn_cast(q_type.getStorageType()).getWidth(); if (width == 8 && q_type.getStorageTypeMax() == 127 && q_type.getStorageTypeMin() == -127) { std::vector quantized_values(real_values_attr.getNumElements()); - if (auto uniform_type = q_type.dyn_cast()) { + if (auto uniform_type = dyn_cast(q_type)) { float min, max, scale; tflite::tensor_utils::SymmetricQuantizeFloats( real_values.data(), real_values.size(), quantized_values.data(), &min, @@ -588,7 +588,7 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, return Quantize(real_value, tensor_type); } } else if (auto uniform_type = - q_type.dyn_cast()) { + dyn_cast(q_type)) { std::vector scales_inv; std::vector dimension; dimension.insert(dimension.end(), new_dense_type.getShape().begin(), @@ -617,7 +617,7 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, // not correctly quantized by legacy quantizer so call the new Quantize. 
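
// The width == 8 branch above delegates to
// tflite::tensor_utils::SymmetricQuantizeFloats, i.e. a symmetric mapping
// onto [-127, 127] with an implicit zero point of 0. A minimal sketch of such
// a quantizer (not the TFLite kernel itself):
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> SymmetricQuantize(const std::vector<float>& values,
                                      float& scale) {
  float abs_max = 0.0f;
  for (const float v : values) abs_max = std::max(abs_max, std::abs(v));
  scale = abs_max / 127.0f;  // Zero point is implicitly 0.
  std::vector<int8_t> quantized;
  quantized.reserve(values.size());
  for (const float v : values) {
    const int q = scale == 0.0f ? 0 : static_cast<int>(std::round(v / scale));
    quantized.push_back(static_cast<int8_t>(std::clamp(q, -127, 127)));
  }
  return quantized;
}
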
return Quantize(real_value, tensor_type); } else if (width == 16) { - if (const auto uniform_type = q_type.dyn_cast()) { + if (const auto uniform_type = dyn_cast(q_type)) { const auto quantized_values = tflite::optimize::utils::SymmetricQuantizeFloatsToInt16( real_values.data(), real_values.size(), uniform_type.getScale()); @@ -630,10 +630,10 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, } } else if (width == 32) { std::vector scales; - if (const auto uniform_type = q_type.dyn_cast()) { + if (const auto uniform_type = dyn_cast(q_type)) { scales.push_back(uniform_type.getScale()); } else if (const auto uniform_type = - q_type.dyn_cast()) { + dyn_cast(q_type)) { scales.insert(scales.end(), uniform_type.getScales().begin(), uniform_type.getScales().end()); } else { @@ -656,8 +656,8 @@ ElementsAttr Quantize(const Attribute real_value, const Type tensor_type) { if (const auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type)) { Type converted_type; - return quantfork::quantizeAttr(real_value, q_type, converted_type) - .dyn_cast_or_null(); + return dyn_cast_or_null( + quantfork::quantizeAttr(real_value, q_type, converted_type)); } return {}; } @@ -678,10 +678,9 @@ quant::QuantizedType DownCastScale(QuantizedType type, if (!type) return type; SmallVector scales(mins.size()); SmallVector zero_points(mins.size()); - if (auto q_type = type.dyn_cast()) { + if (auto q_type = dyn_cast(type)) { zero_points.push_back(q_type.getZeroPoint()); - } else if (auto q_type = - type.dyn_cast()) { + } else if (auto q_type = dyn_cast(type)) { zero_points = {q_type.getZeroPoints().begin(), q_type.getZeroPoints().end()}; } @@ -701,13 +700,12 @@ quant::QuantizedType DownCastScale(QuantizedType type, } } } - if (auto q_type = type.dyn_cast()) { + if (auto q_type = dyn_cast(type)) { return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), scales[0], zero_points[0], q_type.getStorageTypeMin(), q_type.getStorageTypeMax()); - } else if (auto q_type = - type.dyn_cast()) { + } else if (auto q_type = dyn_cast(type)) { return quant::UniformQuantizedPerAxisType::get( q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), scales, zero_points, q_type.getQuantizedDimension(), @@ -722,8 +720,8 @@ quant::QuantizedType DownCastScale(QuantizedType type, static bool PreferResultScale(Operation* op) { int float_operands = 0; for (auto operand : op->getOperands()) { - if (auto operand_type = operand.getType().dyn_cast()) { - if (operand_type.getElementType().isa()) { + if (auto operand_type = dyn_cast(operand.getType())) { + if (isa(operand_type.getElementType())) { if (++float_operands > 1) return true; } } @@ -733,22 +731,22 @@ static bool PreferResultScale(Operation* op) { std::unique_ptr GetDefaultQuantScaleSpec(Operation* op) { auto spec = std::make_unique(); - if (llvm::isa(op)) { + if (isa(op)) { spec->has_same_scale_requirement = true; spec->required_same_scale_func = [op](const bool sign, const int bit_width) { - return llvm::cast(op) + return cast(op) .RequiredSameOperandsAndResultsScale(sign, bit_width); }; spec->required_same_quantized_axes_func = [op]() { - return llvm::cast(op).RequiredSameQuantizedAxes(); + return cast(op).RequiredSameQuantizedAxes(); }; } - if (llvm::isa(op)) { + if (isa(op)) { spec->has_fixed_output_range = true; spec->fixed_output_range_func = [op](bool sign, int bit_width) { - return llvm::cast(op).GetFixedOutputRange( - sign, bit_width); + return cast(op).GetFixedOutputRange(sign, + bit_width); }; } return 
spec; @@ -760,21 +758,21 @@ static bool IsStatsRedundant( Operation* op, const OpQuantSpecGetter op_quant_spec_getter, const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { // If it has FixedOutputRangeInterface, no need to manually create spec. - return llvm::isa(op) || + return isa(op) || op_quant_scale_spec_getter(op)->has_fixed_output_range; } static bool IsSameScaleOp( Operation* op, const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { // If it has SameScalesOpInterface, no need to manually create spec. - return llvm::dyn_cast(op) || + return dyn_cast(op) || op_quant_scale_spec_getter(op)->has_same_scale_requirement; } bool RemoveRedundantStatsOps( - mlir::func::FuncOp func, const OpQuantSpecGetter op_quant_spec_getter, + func::FuncOp func, const OpQuantSpecGetter op_quant_spec_getter, const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { - llvm::SmallVector all_stats_ops; + SmallVector all_stats_ops; llvm::DenseSet redundant_stats_ops; // Step 0: remove the quantfork::StatisticsOp which are used by the @@ -782,8 +780,7 @@ bool RemoveRedundantStatsOps( // ops. func.walk([&](quantfork::QuantizeCastOp q) { auto input_op = q.getArg().getDefiningOp(); - if (auto stats = - llvm::dyn_cast_or_null(input_op)) { + if (auto stats = dyn_cast_or_null(input_op)) { q.setOperand(stats.getArg()); if (stats.use_empty()) stats.erase(); } @@ -820,8 +817,8 @@ bool RemoveRedundantStatsOps( if (!res.hasOneUse()) { continue; } - if (auto next_stats = llvm::dyn_cast( - *res.getUsers().begin())) { + if (auto next_stats = + dyn_cast(*res.getUsers().begin())) { // quantization parameters can be propagated to next_stats redundant_stats_ops.insert(next_stats); // add next_stats to the work list so propagation can continue. @@ -848,7 +845,7 @@ bool RemoveRedundantStatsOps( continue; } for (Value input : def->getOperands()) { - if (auto next_stats = llvm::dyn_cast_or_null( + if (auto next_stats = dyn_cast_or_null( input.getDefiningOp())) { redundant_stats_ops.insert(next_stats); all_stats_ops.push_back(next_stats); @@ -859,8 +856,8 @@ bool RemoveRedundantStatsOps( // Step3: Remove all the redundant stats ops for (Operation* it : redundant_stats_ops) { - if (!llvm::isa(it)) return true; - auto stats_op = llvm::cast(it); + if (!isa(it)) return true; + auto stats_op = cast(it); stats_op.getResult().replaceAllUsesWith(stats_op.getArg()); stats_op.erase(); } @@ -870,9 +867,9 @@ bool RemoveRedundantStatsOps( } LogicalResult VerifySameScales(Operation* op) { - auto same_scale_op = llvm::cast(op); + auto same_scale_op = cast(op); - llvm::SmallVector collected_quant_params; + SmallVector collected_quant_params; for (Value input : op->getOperands()) { QuantizedType quant_params = QuantizedType::getQuantizedElementType(input.getType()); @@ -901,9 +898,9 @@ LogicalResult VerifySameScales(Operation* op) { // method. 
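
// RemoveRedundantStatsOps (above) is a classic worklist traversal: stats ops
// are popped from a vector, marked in a set, and their producers and
// consumers push further candidates. A generic sketch of the shape of that
// algorithm, independent of MLIR:
#include <unordered_set>
#include <vector>

template <typename Node, typename ExpandFn>
std::unordered_set<Node> CollectReachable(std::vector<Node> worklist,
                                          ExpandFn expand) {
  std::unordered_set<Node> visited;
  while (!worklist.empty()) {
    const Node node = worklist.back();
    worklist.pop_back();
    if (!visited.insert(node).second) continue;  // Already processed.
    for (const Node& next : expand(node)) worklist.push_back(next);
  }
  return visited;
}
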
if (!same_scale_op.RequiredSameQuantizedAxes()) { const auto expected_per_axis_qtype = - expected_params.dyn_cast(); + dyn_cast(expected_params); const auto compared_per_axis_qtype = - compared_params.dyn_cast(); + dyn_cast(compared_params); if (expected_per_axis_qtype && compared_per_axis_qtype && llvm::equal(expected_per_axis_qtype.getScales(), compared_per_axis_qtype.getScales()) && @@ -945,8 +942,8 @@ quant::UniformQuantizedType GetFixedOutputRange( const bool is_signed, const int bit_width, const Type tensor_type, const double scale, int64_t zero_point, int64_t storage_min, int64_t storage_max) { - const auto result_type = tensor_type.cast(); - if (!result_type.getElementType().isa()) return {}; + const auto result_type = cast(tensor_type); + if (!isa(result_type.getElementType())) return {}; Builder builder(result_type.getContext()); // Only support 8-bits and 16-bits @@ -988,17 +985,17 @@ Type ConvertSignedQuantizedToUnsigned(const Type signed_tensor_type, const auto flags = !quant::QuantizationFlags::Signed; QType new_qtype; - if (auto uqtype = qtype.dyn_cast()) { + if (auto uqtype = dyn_cast(qtype)) { new_qtype = quant::UniformQuantizedType::getChecked( loc, flags, qtype.getStorageType(), qtype.getExpressedType(), uqtype.getScale(), uqtype.getZeroPoint() - offset, uqtype.getStorageTypeMin() - offset, uqtype.getStorageTypeMax() - offset); } else if (auto aqtype = - qtype.dyn_cast()) { + dyn_cast(qtype)) { const auto zero_points = aqtype.getZeroPoints(); - llvm::SmallVector new_zero_points(zero_points.begin(), - zero_points.end()); + SmallVector new_zero_points(zero_points.begin(), + zero_points.end()); for (int i = 0; i < new_zero_points.size(); ++i) { new_zero_points[i] -= offset; } diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h index 453dc419371932..3f9f56d45fbaa7 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h @@ -493,6 +493,7 @@ class QuantizationPattern : public RewritePattern { continue; } + bool is_operand_or_result_modified = false; // Collect all the quantized inputs and "clone" the matched op by these // inputs. SmallVector inputs; @@ -517,6 +518,7 @@ class QuantizationPattern : public RewritePattern { // Dynamic range quantization is applied by having QuantizeOp as an // input. Only int8 weight is supported for now. inputs.push_back(dq_op.getOperand()); + is_operand_or_result_modified = true; } else { // Otherwise, it's the case where the operand is activations or the // quantizing_op is non-supported/weight-only. @@ -525,6 +527,7 @@ class QuantizationPattern : public RewritePattern { } else { if (auto dq_op = dyn_cast_or_null(operand.getDefiningOp())) { + is_operand_or_result_modified = true; inputs.push_back(dq_op.getOperand()); } else if (!ele_type.isF32()) { // If the operand is an integer tensor, then it doesn't require the @@ -561,6 +564,7 @@ class QuantizationPattern : public RewritePattern { outputs_replaced.insert( {user.getResult(), enumerated_result.index()}); output_types.push_back(user.getType()); + is_operand_or_result_modified = true; } else if (!result_ele_type.isF32()) { // If the result is an integer tensor, then it doesn't require the // D op in the pattern. 
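
// ConvertSignedQuantizedToUnsigned (above) keeps the represented real values
// identical while flipping the storage sign: every storage-domain quantity is
// shifted by -offset, where offset is assumed to be the signed storage
// minimum (e.g. -128 for 8 bits), so zp_unsigned = zp_signed + 2^(bits-1). A
// sketch under that assumption, with illustrative names:
#include <cstdint>

struct QuantParams {
  double scale;
  int64_t zero_point;
  int64_t storage_min;
  int64_t storage_max;
};

QuantParams SignedToUnsigned(const QuantParams& signed_params,
                             const int num_bits) {
  const int64_t offset = -(int64_t{1} << (num_bits - 1));  // e.g. -128
  // Subtracting the negative offset shifts [-128, 127] to [0, 255] while the
  // scale, and hence every represented real value, is unchanged.
  return {signed_params.scale, signed_params.zero_point - offset,
          signed_params.storage_min - offset,
          signed_params.storage_max - offset};
}
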
@@ -576,6 +580,13 @@ class QuantizationPattern : public RewritePattern {
         }
       }

+      // For float16 quantization, skip replacing the op when none of its
+      // operands or results was modified. See b/335025403.
+      if (inference_type == tensorflow::DT_HALF &&
+          !is_operand_or_result_modified) {
+        return failure();
+      }
+
       rewriter.setInsertionPointAfter(quantizing_op);
       OperationState new_state(quantizing_op->getLoc(),
                                quantizing_op->getName().getStringRef(), inputs,
diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc
index a64ba201250727..7f66d76798acfa 100644
--- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc
+++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc
@@ -94,12 +94,12 @@ bool IsStorageTypeI32(const QuantizedType quantized_type) {

 bool IsExpressedTypeF32(const QuantizedType quantized_type) {
   const Type expressed_type = quantized_type.getExpressedType();
-  return expressed_type.isa<Float32Type>();
+  return mlir::isa<Float32Type>(expressed_type);
 }

 bool IsI8F32UniformQuantizedType(const Type type) {
   const UniformQuantizedType quantized_type =
-      type.dyn_cast_or_null<UniformQuantizedType>();
+      mlir::dyn_cast_or_null<UniformQuantizedType>(type);
   if (!quantized_type) {
     LLVM_DEBUG(llvm::dbgs() << "Expected a uniform quantized type. Got: "
                             << type << ".\n");
@@ -123,7 +123,7 @@ bool IsI8F32UniformQuantizedType(const Type type) {

 bool IsI8F32UniformQuantizedPerAxisType(const Type type) {
   const UniformQuantizedPerAxisType quantized_per_axis_type =
-      type.dyn_cast_or_null<UniformQuantizedPerAxisType>();
+      mlir::dyn_cast_or_null<UniformQuantizedPerAxisType>(type);
   if (!quantized_per_axis_type) {
     LLVM_DEBUG(llvm::dbgs() << "Expected a uniform quantized type. Got: "
                             << type << ".\n");
@@ -147,7 +147,7 @@ bool IsI8F32UniformQuantizedPerAxisType(const Type type) {

 bool IsI32F32UniformQuantizedType(const Type type) {
   const UniformQuantizedType quantized_type =
-      type.dyn_cast_or_null<UniformQuantizedType>();
+      mlir::dyn_cast_or_null<UniformQuantizedType>(type);
   if (!quantized_type) {
     LLVM_DEBUG(llvm::dbgs() << "Expected a uniform quantized type. Got: "
                             << type << ".\n");
@@ -171,7 +171,7 @@ bool IsI32F32UniformQuantizedType(const Type type) {

 bool IsI32F32UniformQuantizedPerAxisType(const Type type) {
   const UniformQuantizedPerAxisType quantized_per_axis_type =
-      type.dyn_cast_or_null<UniformQuantizedPerAxisType>();
+      mlir::dyn_cast_or_null<UniformQuantizedPerAxisType>(type);
   if (!quantized_per_axis_type) {
     LLVM_DEBUG(llvm::dbgs() << "Expected a uniform quantized type. Got: "
                             << type << ".\n");
@@ -208,11 +208,11 @@ bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) {
 }

 bool IsQuantizedTensorType(Type type) {
-  if (!type.isa<TensorType>()) {
+  if (!mlir::isa<TensorType>(type)) {
     return false;
   }
-  Type element_type = type.cast<TensorType>().getElementType();
-  return element_type.isa<QuantizedType>();
+  Type element_type = mlir::cast<TensorType>(type).getElementType();
+  return mlir::isa<QuantizedType>(element_type);
 }

 bool IsOpFullyQuantized(Operation* op) {
diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h
index ab850c878ff0dd..e30db98a9616de 100644
--- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h
+++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h
@@ -82,7 +82,7 @@ bool IsExpressedTypeF32(QuantizedType quantized_type);

 // Given a value, extract the `ElementType`.
 // `value` should be a non-null `TensorType`.
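
// The is_operand_or_result_modified flag added to QuantizationPattern above
// guards the float16 path: if no operand came from a dequantize op and no
// result feeds a quantize op, the rewrite would only clone the op without
// changing anything, so the pattern bails out. The guard reduces to logic
// like this sketch (illustrative names, not the pattern's real interface):
bool ShouldRewriteForFloat16(const bool is_float16_inference,
                             const bool operand_or_result_modified) {
  // Only rewrite when the rewrite would actually change something.
  return !is_float16_inference || operand_or_result_modified;
}
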
inline Type GetElementType(const Value value) { - return value.getType().cast().getElementType(); + return mlir::cast(value.getType()).getElementType(); } // Returns true iff `type` is a uniform quantized type whose storage type is diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc index e9443a667fcef3..d4055b1732b1d8 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc @@ -348,7 +348,8 @@ TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); - EXPECT_THAT(qi8_type.dyn_cast_or_null(), NotNull()); + EXPECT_THAT(mlir::dyn_cast_or_null(qi8_type), + NotNull()); } TEST_F(IsI8F32UniformQuantizedTypeTest, StorageTypeI8Succeeds) { @@ -398,8 +399,9 @@ TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { /*scales=*/{1.0}, /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); - EXPECT_THAT(qi8_per_axis_type.dyn_cast_or_null(), - NotNull()); + EXPECT_THAT( + mlir::dyn_cast_or_null(qi8_per_axis_type), + NotNull()); } TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, StorageTypeI8Succeeds) { @@ -452,7 +454,8 @@ TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); - EXPECT_THAT(qi32_type.dyn_cast_or_null(), NotNull()); + EXPECT_THAT(mlir::dyn_cast_or_null(qi32_type), + NotNull()); } TEST_F(IsI32F32UniformQuantizedTypeTest, StorageTypeI32Succeeds) { @@ -509,7 +512,7 @@ TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, /*storageTypeMax=*/127); EXPECT_FALSE(IsI32F32UniformQuantizedPerAxisType(qi8_type)); EXPECT_FALSE(IsStorageTypeI32(qi8_type)); - EXPECT_THAT(qi8_type.dyn_cast_or_null(), + EXPECT_THAT(mlir::dyn_cast_or_null(qi8_type), IsNull()); } @@ -523,7 +526,7 @@ TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); EXPECT_THAT( - qi32_per_axis_type.dyn_cast_or_null(), + mlir::dyn_cast_or_null(qi32_per_axis_type), NotNull()); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 3da423119752cb..db53084419f1f9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -51,6 +51,7 @@ cc_library( "passes/convert_xla_call_module_op_to_bfloat16.cc", "passes/defer_activation_transpose.cc", "passes/fold_constant_transpose.cc", + "passes/insert_calibration_statistics_saver.cc", "passes/insert_weight_param.cc", "passes/lift_quantizable_spots_as_functions.cc", "passes/lift_quantizable_spots_as_functions_fusion.inc", @@ -99,6 +100,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:permutation", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:report", + "//tensorflow/compiler/mlir/quantization/stablehlo/instrumentations:save_report", "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", 
"//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -138,6 +140,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:str_util", @@ -572,6 +575,7 @@ tf_cc_test( deps = [ ":math_utils", "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index 5ae92d648bf5c9..f175dfdc9ea1a2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -40,6 +40,8 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings:string_view", ], ) @@ -62,7 +64,9 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) @@ -75,6 +79,7 @@ tf_cc_test( "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:status", @@ -110,10 +115,7 @@ cc_library( hdrs = ["debugger.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":graph_def", - "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", - "//tensorflow/core:protos_all_cc", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], @@ -321,9 +323,12 @@ cc_library( hdrs = ["report.h"], compatible_with = get_compatible_with_portable(), deps = [ + ":io", "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -337,13 +342,18 @@ tf_cc_test( name = "report_test", srcs = ["report_test.cc"], deps = [ + ":io", ":report", "//tensorflow/compiler/mlir/quantization/common:test_base", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status_matchers", ], ) @@ -365,8 +375,10 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":component", + ":config", ":pass_pipeline", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/instrumentations:save_report", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", 
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "@com_google_absl//absl/base:nullability", @@ -436,12 +448,14 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":component", + ":config", ":context", ":pass_pipeline", ":saved_model_export", ":saved_model_import", ":types", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/instrumentations:save_report", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD index 3fbd4ed586e45f..9926546f8c47a8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD @@ -25,15 +25,22 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:min_max_value", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:statusor", ], ) @@ -45,6 +52,7 @@ cc_library( deps = [ ":representative_dataset", ":statistics", + "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:component", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger", @@ -53,8 +61,11 @@ cc_library( "//tensorflow/compiler/mlir/quantization/stablehlo/cc:types", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -67,6 +78,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h 
b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h index ffad37d15d243c..9e1950afa76dba 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h @@ -23,10 +23,6 @@ limitations under the License. namespace stablehlo::quantization { -// TODO: b/321158562 - Make the number of bins configurable. -// Default number of histogram bins for each batch sample. -constexpr int32_t kDefaultNumOfBins = 1 << 9; - // Calculates the bin width from the range and expected number of bins. The // bin width is formalized to the form of 2^n. As a consequence, the actual // number of bins might be smaller than the given `num_bins`. @@ -70,8 +66,10 @@ inline bool IsHistogramCalibration( } // Gets the number of bins for the given calibration method. -inline int32_t GetNumBins(const CalibrationOptions::CalibrationMethod method) { - return IsHistogramCalibration(method) ? kDefaultNumOfBins : 0; +inline int32_t GetNumBins(const CalibrationOptions& calib_opts) { + return IsHistogramCalibration(calib_opts.calibration_method()) + ? calib_opts.calibration_parameters().num_bins() + : 0; } } // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc index ce626145318b9f..52db906e512391 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/log/die_if_null.h" @@ -40,26 +41,67 @@ limitations under the License. 
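
// calibration_parameters.h (above) notes that the histogram bin width is
// formalized to the form 2^n, so the actual bin count can be smaller than the
// requested num_bins. A hedged sketch of that calculation, assuming the
// conventional rounding-up to a power of two (the real CalculateBinWidth is
// not shown in this hunk) and a strictly positive range:
#include <cmath>
#include <cstdint>

float CalculateBinWidthSketch(const float min_value, const float max_value,
                              const int32_t num_bins) {
  const float raw_bin_width = (max_value - min_value) / num_bins;
  // Rounding the width up to the nearest power of two widens each bin, which
  // means fewer effective bins than requested.
  return std::pow(2.0f, std::ceil(std::log2(raw_bin_width)));
}
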
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace mlir::quant::stablehlo { +namespace { using ::stablehlo::quantization::AddCalibrationStatistics; using ::stablehlo::quantization::CreateRepresentativeDatasetFileMap; using ::stablehlo::quantization::DisableDebugging; +using ::stablehlo::quantization::IsCalibrationRequired; using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::ReadStatistics; using ::stablehlo::quantization::RepresentativeDatasetConfig; using ::stablehlo::quantization::io::CreateTmpDir; using ::stablehlo::quantization::io::GetLocalTmpFileName; +using ::stablehlo::quantization::io::ListDirectory; using ::tensorflow::AssetFileDef; using ::tensorflow::SignatureDef; +using ::tensorflow::calibrator::CalibrationStatistics; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; +using ::tensorflow::quantization::RunPasses; +using CalibrationStatisticsFlatMap = + absl::flat_hash_map; + +} // namespace + +absl::Status RunCalibrationPasses( + mlir::ModuleOp module_op, MLIRContext& ctx, + absl::string_view calibration_data_dir, + const bool force_regenerate_calibration_data) { + // Disable DumpTensor ops when running calibration. + DisableDebugging(module_op); + + std::vector skipping_aggregator_ops; + if (!force_regenerate_calibration_data) { + TF_ASSIGN_OR_RETURN(const CalibrationStatisticsFlatMap statistics_map, + ReadStatistics(calibration_data_dir)); + absl::c_for_each(statistics_map, [&](const auto& iter) { + return skipping_aggregator_ops.push_back(iter.first); + }); + } + + return RunPasses( + /*name=*/ + CalibrationComponent::kName, + /*add_passes_func=*/ + [calibration_data_dir, &skipping_aggregator_ops](PassManager& pm) { + pm.addPass(CreateInsertCalibrationStatisticsSaverPass( + calibration_data_dir, skipping_aggregator_ops)); + }, + ctx, module_op); +} CalibrationComponent::CalibrationComponent( absl::Nonnull ctx, @@ -77,16 +119,23 @@ CalibrationComponent::CalibrationComponent( signature_def_map_(std::move(signature_def_map)), signature_keys_(std::move(signature_keys)) {} -absl::StatusOr CalibrationComponent::ExportToSavedModel( - ModuleOp module_op, const absl::string_view dst_saved_model_path) { +absl::Status CalibrationComponent::ExportToSavedModel( + ModuleOp module_op, absl::string_view calibration_data_dir, + const bool force_regenerate_calibration_data, + const absl::string_view dst_saved_model_path) { TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); // Clone ModuleOp and function aliases so changes in this pipeline won't // be reflected in the original values. 
mlir::OwningOpRef cloned_module_ref(module_op.clone()); - // Disable DumpTensor ops when running calibration. - DisableDebugging(*cloned_module_ref); + TF_RETURN_IF_ERROR(RunCalibrationPasses(*cloned_module_ref, *ctx_, + calibration_data_dir, + force_regenerate_calibration_data)); + + const bool is_calibration_required = + IsCalibrationRequired(*cloned_module_ref); + if (!is_calibration_required) return absl::OkStatus(); // `duplicate_shape_determining_constants = false` because the // resulting graph of this step is not expected to be loaded on TPU. @@ -107,42 +156,52 @@ absl::StatusOr CalibrationComponent::ExportToSavedModel( src_saved_model_path_, tags_, signature_def_map_); - return exported_model; + return absl::OkStatus(); } absl::StatusOr CalibrationComponent::Run( ModuleOp module_op, const QuantizationConfig& config) { - // Exports the pre-calibrated model to SavedModel. - TF_ASSIGN_OR_RETURN(const std::string precalibrated_saved_model_dir, + // Export the calibration model to SavedModel. + TF_ASSIGN_OR_RETURN(const std::string calibration_saved_model_dir, CreateTmpDir()); - TF_ASSIGN_OR_RETURN( - ExportedModel exported_model, - ExportToSavedModel(module_op, precalibrated_saved_model_dir)); - - // Translates `RepresentativeDatasetConfig`s to signature key -> - // `RepresentativeDatasetFile` mapping. - const auto dataset_configs = - config.calibration_options().representative_datasets(); - const std::vector dataset_config_vector( - dataset_configs.begin(), dataset_configs.end()); - TF_ASSIGN_OR_RETURN( - const auto representative_dataset_file_map, - CreateRepresentativeDatasetFileMap(dataset_config_vector)); - - // Runs calibration on the exported model. The statistics will be stored in a - // separate singleton object `CalibratorSingleton` and are directly added to - // `exported_model` without re-importing it. - if (py_function_lib_->RunCalibration( - precalibrated_saved_model_dir, signature_keys_, tags_, - /*force_graph_mode_calibration=*/true, - representative_dataset_file_map) == std::nullopt) { - return absl::InternalError( - "CalibrationComponent error: Failed to run calibration."); + std::string calibration_data_dir = + config.calibration_options().calibration_data_dir(); + if (calibration_data_dir.empty()) { + TF_ASSIGN_OR_RETURN(calibration_data_dir, CreateTmpDir()); + } + + TF_RETURN_IF_ERROR(ExportToSavedModel( + module_op, calibration_data_dir, + config.calibration_options().force_regenerate_calibration_data(), + calibration_saved_model_dir)); + + TF_ASSIGN_OR_RETURN(std::vector calibration_saved_model_files, + ListDirectory(calibration_saved_model_dir)); + if (!calibration_saved_model_files.empty()) { + // Translate `RepresentativeDatasetConfig`s to signature key -> + // `RepresentativeDatasetFile` mapping. + const auto dataset_configs = + config.calibration_options().representative_datasets(); + const std::vector dataset_config_vector( + dataset_configs.begin(), dataset_configs.end()); + TF_ASSIGN_OR_RETURN( + const auto representative_dataset_file_map, + CreateRepresentativeDatasetFileMap(dataset_config_vector)); + + // Run calibration on the exported model. 
+ if (py_function_lib_->RunCalibration( + calibration_saved_model_dir, signature_keys_, tags_, + /*force_graph_mode_calibration=*/true, + representative_dataset_file_map) == std::nullopt) { + return absl::InternalError( + "CalibrationComponent error: Failed to run calibration."); + } } if (absl::Status status = AddCalibrationStatistics( - module_op, config.calibration_options(), *py_function_lib_); + module_op, calibration_data_dir, config.calibration_options(), + *py_function_lib_); !status.ok()) { LOG(WARNING) << "Some CustomAggregator ops do not have min or max " "values. Parts of the graph are not quantized. " diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h index cb137031948a3a..03d2dd933732d4 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -76,9 +77,11 @@ class CalibrationComponent : public Component { // Exports `module_op` to SavedModel at `dst_saved_model_path`. This is used // to export the pre-calibrated `module_op` to SavedModel so that the // calibration process can use it to load and run the graph with the - // representative dataset. - absl::StatusOr ExportToSavedModel( - ModuleOp module_op, absl::string_view dst_saved_model_path); + // representative dataset. Returns a failure status if the export fails. + absl::Status ExportToSavedModel(ModuleOp module_op, + absl::string_view calibration_data_dir, + bool force_regenerate_calibration_data, + absl::string_view dst_saved_model_path); // Imports the SavedModel at `calibrated_saved_model_path` to `ModuleOp` after // running calibration. @@ -109,6 +112,11 @@ class CalibrationComponent : public Component { const std::vector signature_keys_; }; +// Runs passes to prepare the calibration model. +absl::Status RunCalibrationPasses(mlir::ModuleOp module_op, MLIRContext& ctx, + absl::string_view calibration_data_dir, + bool force_regenerate_calibration_data); + } // namespace mlir::quant::stablehlo #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_COMPONENT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc index 19a44097458f1a..ea96bd029b079e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc @@ -15,39 +15,69 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" #include +#include +#include +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tsl/platform/path.h" +#include "tsl/platform/statusor.h" namespace stablehlo::quantization { namespace { using ::stablehlo::quantization::CalibrationOptions; using ::tensorflow::calibrator::CalibrationStatistics; -using ::tensorflow::calibrator::CalibratorSingleton; +using ::tensorflow::calibrator::CalibrationStatisticsMap; using ::tensorflow::quantization::PyFunctionLibrary; +using CalibrationStatisticsFlatMap = + absl::flat_hash_map; } // namespace +// Reads the calibration statistics from the given directory. +absl::StatusOr ReadStatistics( + absl::string_view calibration_data_dir) { + TF_ASSIGN_OR_RETURN(std::vector statistics_files, + io::ListDirectory(calibration_data_dir)); + + CalibrationStatisticsFlatMap statistics_map; + for (const std::string& statistics_file : statistics_files) { + TF_ASSIGN_OR_RETURN( + const auto single_map, + io::ReadBinaryProto( + tsl::io::JoinPath(calibration_data_dir, statistics_file))); + statistics_map.insert(single_map.statistics().begin(), + single_map.statistics().end()); + } + return statistics_map; +} + absl::Status AddCalibrationStatistics( - mlir::ModuleOp module_op, const CalibrationOptions& calibration_options, + mlir::ModuleOp module_op, absl::string_view calibration_data_dir, + const CalibrationOptions& calibration_options, const PyFunctionLibrary& py_function_library) { + TF_ASSIGN_OR_RETURN(const CalibrationStatisticsFlatMap statistics_map, + ReadStatistics(calibration_data_dir)); + absl::Status status = absl::OkStatus(); - module_op.walk([&py_function_library, &calibration_options, - &status](mlir::TF::CustomAggregatorOp aggregator_op) { + module_op.walk([&py_function_library, &calibration_options, &status, + &statistics_map](mlir::TF::CustomAggregatorOp aggregator_op) { mlir::StringRef id = aggregator_op.getId(); - std::optional statistics = - CalibratorSingleton::GetStatistics(id); - if (statistics == std::nullopt) { + auto iter = statistics_map.find(id); + if (iter == statistics_map.end()) { status = absl::InternalError( absl::StrFormat("Calibrated data does not exist. Cannot find " "statistics. 
value for id: %s", @@ -56,10 +86,8 @@ absl::Status AddCalibrationStatistics( } const std::optional min_max_values = - py_function_library.GetCalibrationMinMaxValue(*statistics, + py_function_library.GetCalibrationMinMaxValue(iter->second, calibration_options); - CalibratorSingleton::ClearData(id); - if (min_max_values == std::nullopt) { status = absl::InternalError( "Cannot find min/max values for calibration statistics."); @@ -74,4 +102,14 @@ absl::Status AddCalibrationStatistics( return status; } +bool IsCalibrationRequired(mlir::ModuleOp module_op) { + bool calibration_required = false; + module_op.walk( + [&calibration_required]( + mlir::TF::CalibrationStatisticsSaverOp statistics_saver_op) { + calibration_required = true; + }); + return calibration_required; +} + } // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h index 9b67f22a2dac72..41f78be3578bca 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h @@ -15,22 +15,36 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ +#include + +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" namespace stablehlo::quantization { +// Reads the calibration statistics from the given directory. +absl::StatusOr> +ReadStatistics(absl::string_view calibration_data_dir); + // Adds calibrated min / max values to CustomAggregator nodes in `graph_def`. // The min and max values will be added to the "min" and "max" attributes, // respectively. `calibration_options` provides the strategy to retrieve min and // max values. absl::Status AddCalibrationStatistics( - mlir::ModuleOp module_op, + mlir::ModuleOp module_op, absl::string_view calibration_data_dir, const stablehlo::quantization::CalibrationOptions& calibration_options, const tensorflow::quantization::PyFunctionLibrary& py_function_library); +// Checks if the model required calibration. 
+bool IsCalibrationRequired(mlir::ModuleOp module_op); + } // namespace stablehlo::quantization #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc index b3aa1500a0a3c7..1522c68f300cba 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc @@ -29,59 +29,35 @@ void PopulateDefaultCalibrationOptions(QuantizationConfig& quant_config) { quant_config.mutable_calibration_options()->set_calibration_method( CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); } + switch (quant_config.calibration_options().calibration_method()) { - case CalibrationOptions::CALIBRATION_METHOD_MIN_MAX: - break; - case CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX: - break; case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE: - if (quant_config.calibration_options() - .calibration_parameters() - .initial_num_bins() == 0) { - quant_config.mutable_calibration_options() - ->mutable_calibration_parameters() - ->set_initial_num_bins(256); - } - if (quant_config.calibration_options() - .calibration_parameters() - .min_percentile() == 0) { - quant_config.mutable_calibration_options() - ->mutable_calibration_parameters() - ->set_min_percentile(0.001); - } - if (quant_config.calibration_options() - .calibration_parameters() - .max_percentile() == 0) { - quant_config.mutable_calibration_options() - ->mutable_calibration_parameters() - ->set_max_percentile(99.999); - } - break; case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE: - if (quant_config.calibration_options() - .calibration_parameters() - .initial_num_bins() == 0) { - quant_config.mutable_calibration_options() - ->mutable_calibration_parameters() - ->set_initial_num_bins(256); - } - break; case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY: - if (quant_config.calibration_options() - .calibration_parameters() - .initial_num_bins() == 0) { - quant_config.mutable_calibration_options() - ->mutable_calibration_parameters() - ->set_initial_num_bins(256); - } - break; case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC: if (quant_config.calibration_options() .calibration_parameters() - .initial_num_bins() == 0) { + .num_bins() == 0) { quant_config.mutable_calibration_options() ->mutable_calibration_parameters() - ->set_initial_num_bins(256); + ->set_num_bins(512); + } + if (quant_config.calibration_options().calibration_method() == + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE) { + if (quant_config.calibration_options() + .calibration_parameters() + .min_percentile() == 0) { + quant_config.mutable_calibration_options() + ->mutable_calibration_parameters() + ->set_min_percentile(0.001); + } + if (quant_config.calibration_options() + .calibration_parameters() + .max_percentile() == 0) { + quant_config.mutable_calibration_options() + ->mutable_calibration_parameters() + ->set_max_percentile(99.999); + } } break; default: @@ -109,11 +85,18 @@ QuantizationSpec GetDefaultStaticRangePtqSpec(StaticRangePtqPreset preset) { return spec; } -QuantizationSpec GetDefaultWeightOnlyPtqSpec(WeightOnlyPtqPreset preset) { +QuantizationSpec GetDefaultWeightOnlyPtqSpec() { QuantizationSpec spec{}; spec.mutable_matcher()->mutable_function_name()->set_regex( "^.*(conv|dot_general).*"); - spec.mutable_method()->mutable_weight_only_ptq(); + + WeightOnlyPtq& 
weight_only_ptq_spec = + *spec.mutable_method()->mutable_weight_only_ptq(); + if (auto [iter, inserted] = + weight_only_ptq_spec.mutable_input_quantized_types()->try_emplace(1); + inserted) { + iter->second.mutable_dimension_specs(); + } return spec; } @@ -133,6 +116,9 @@ QuantizationSpec GetDefaultWeightOnlyPtqSpec(WeightOnlyPtqPreset preset) { // } QuantizationSpec GetPtqSpecForConvolution(Method::MethodCase method_case) { QuantizationSpec spec{}; + if (method_case != Method::kStaticRangePtq) { + return spec; + } // Matches all convolution quantizable unit family. spec.mutable_matcher()->mutable_function_name()->set_regex( @@ -147,18 +133,10 @@ QuantizationSpec GetPtqSpecForConvolution(Method::MethodCase method_case) { // The index of weight operands passed to lifted functions for convolution // is 1. - if (method_case == Method::kStaticRangePtq) { - StaticRangePtq& static_range_ptq_spec = - *spec.mutable_method()->mutable_static_range_ptq(); - static_range_ptq_spec.mutable_input_quantized_types()->try_emplace( - 1, std::move(conv_weight_quantized_type)); - } else if (method_case == Method::kWeightOnlyPtq) { - WeightOnlyPtq& weight_only_ptq_spec = - *spec.mutable_method()->mutable_weight_only_ptq(); - weight_only_ptq_spec.mutable_input_quantized_types()->try_emplace( - 1, std::move(conv_weight_quantized_type)); - } - + StaticRangePtq& static_range_ptq_spec = + *spec.mutable_method()->mutable_static_range_ptq(); + static_range_ptq_spec.mutable_input_quantized_types()->try_emplace( + 1, std::move(conv_weight_quantized_type)); return spec; }; @@ -192,15 +170,12 @@ void ExpandStaticRangePtqPreset(const StaticRangePtqPreset& preset, config.mutable_specs()->Swap(&new_specs); } -void ExpandWeightOnlyPtqPreset(const WeightOnlyPtqPreset& preset, - QuantizationConfig& config) { +void ExpandWeightOnlyPtqPreset(QuantizationConfig& config) { // Create a new `QuantizationSpecs` to replace the existing one. The // expansion from `WeightOnlyPtqPreset` gets populated first and then // user-provided explicit `QuantizationSpec`s will be appended. QuantizationSpecs new_specs{}; - *new_specs.add_specs() = - GetDefaultWeightOnlyPtqSpec(/*preset=*/config.weight_only_ptq_preset()); - // TODO: b/307625297 - Add per-channel weight only support. + *new_specs.add_specs() = GetDefaultWeightOnlyPtqSpec(); // Append user-provided specs to override existing specs. const QuantizationSpecs& previous_specs = config.specs(); @@ -222,7 +197,7 @@ QuantizationConfig ExpandPresets(const QuantizationConfig& config) { ExpandStaticRangePtqPreset(config.static_range_ptq_preset(), new_config); break; case QuantizationConfig::kWeightOnlyPtqPreset: - ExpandWeightOnlyPtqPreset(config.weight_only_ptq_preset(), new_config); + ExpandWeightOnlyPtqPreset(new_config); break; default: // Preset has not been specified. The expansion is a no-op. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h index 19f250bedfe1b8..f668cacd41ba2d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h @@ -15,6 +15,10 @@ limitations under the License. 
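
// With the consolidation in config.cc above, every histogram-based method now
// shares the single num_bins default (512), and only the percentile method
// additionally defaults min/max percentiles. A usage sketch of the resulting
// behavior, assuming the PopulateDefaults declaration from config.h and
// mirroring the updated tests:
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"

void PopulateDefaultsExample() {
  using ::stablehlo::quantization::CalibrationOptions;
  using ::stablehlo::quantization::PopulateDefaults;
  using ::stablehlo::quantization::QuantizationConfig;

  QuantizationConfig config;
  config.mutable_calibration_options()->set_calibration_method(
      CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE);
  const QuantizationConfig populated = PopulateDefaults(config);
  // populated.calibration_options().calibration_parameters().num_bins()
  // now returns 512 unless the user already set a non-zero value.
}
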
 #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_
 #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_
 
+#include <optional>
+
+#include "absl/base/attributes.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
 
 namespace stablehlo::quantization {
@@ -45,6 +49,17 @@ QuantizationConfig ExpandPresets(const QuantizationConfig& config);
 bool HasQuantizationMethod(const QuantizationSpecs& specs,
                            Method::MethodCase method_case);
 
+// Convenience function for converting the optional `report_file_path` field to
+// `std::optional<absl::string_view>`, where `std::nullopt` represents that the
+// field is not explicitly set. The returned value is a reference type
+// (`absl::string_view`) so its lifetime is bound to the input `config`.
+inline std::optional<absl::string_view> GetReportFilePath(
+    const QuantizationConfig& config ABSL_ATTRIBUTE_LIFETIME_BOUND) {
+  return config.has_report_file_path()
+             ? std::make_optional<absl::string_view>(config.report_file_path())
+             : std::nullopt;
+}
+
 }  // namespace stablehlo::quantization
 
 #endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc
index c46daaf1252f26..4662339f85624f 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc
@@ -24,6 +24,7 @@ namespace {
 using ::testing::Eq;
 using ::testing::SizeIs;
 using ::testing::StrEq;
+using ::testing::Truly;
 
 TEST(PopulateDefaultsTest, PopulateDefaultsForEmptyConfig) {
   QuantizationConfig config{};
@@ -69,18 +70,16 @@ TEST(PopulateDefaultsTest, ExplicitCalibrationOptionsNotOverridden) {
       *config.mutable_calibration_options();
   calibration_options.set_calibration_method(
       CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX);
-  calibration_options.mutable_calibration_parameters()->set_initial_num_bins(
-      512);
+  calibration_options.mutable_calibration_parameters()->set_num_bins(512);
 
   // Test that if the user explicitly provided `calibration_options`, it is not
   // overridden.
   const QuantizationConfig new_config = PopulateDefaults(config);
   EXPECT_THAT(new_config.calibration_options().calibration_method(),
               Eq(CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX));
-  EXPECT_THAT(new_config.calibration_options()
-                  .calibration_parameters()
-                  .initial_num_bins(),
-              Eq(512));
+  EXPECT_THAT(
+      new_config.calibration_options().calibration_parameters().num_bins(),
+      Eq(512));
 }
 
 TEST(PopulateDefaultsTest, DefaultNumbersPopulatedForPartOfCalibrationOptions) {
@@ -89,18 +88,16 @@ TEST(PopulateDefaultsTest, DefaultNumbersPopulatedForPartOfCalibrationOptions) {
       *config.mutable_calibration_options();
   calibration_options.set_calibration_method(
       CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE);
-  calibration_options.mutable_calibration_parameters()->set_initial_num_bins(
-      512);
+  calibration_options.mutable_calibration_parameters()->set_num_bins(512);
 
   // Test that if the user explicitly provided part of the
   // `calibration_options`, it is not overridden, rest of the data are default.
const QuantizationConfig new_config = PopulateDefaults(config); EXPECT_THAT(new_config.calibration_options().calibration_method(), Eq(CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE)); - EXPECT_THAT(new_config.calibration_options() - .calibration_parameters() - .initial_num_bins(), - Eq(512)); + EXPECT_THAT( + new_config.calibration_options().calibration_parameters().num_bins(), + Eq(512)); EXPECT_THAT(new_config.calibration_options() .calibration_parameters() .min_percentile(), @@ -123,10 +120,9 @@ TEST(PopulateDefaultsTest, EXPECT_THAT( new_config.calibration_options().calibration_method(), Eq(CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE)); - EXPECT_THAT(new_config.calibration_options() - .calibration_parameters() - .initial_num_bins(), - Eq(256)); + EXPECT_THAT( + new_config.calibration_options().calibration_parameters().num_bins(), + Eq(512)); EXPECT_THAT(new_config.calibration_options() .calibration_parameters() .min_percentile(), @@ -171,10 +167,12 @@ TEST(ExpandPresetsTest, ExpandStaticRangePtqEnableFullIntquantization) { const StaticRangePtq& srq_spec = conv_spec.method().static_range_ptq(); ASSERT_THAT(srq_spec.input_quantized_types(), SizeIs(1)); ASSERT_TRUE(srq_spec.input_quantized_types().contains(1)); + ASSERT_TRUE(srq_spec.input_quantized_types().at(1).has_dimension_specs()); - EXPECT_THAT( - srq_spec.input_quantized_types().at(1).dimension_specs().dimension(), - Eq(3)); + const QuantizedDimension& dimension_specs = + srq_spec.input_quantized_types().at(1).dimension_specs(); + ASSERT_TRUE(dimension_specs.has_dimension()); + EXPECT_THAT(dimension_specs.dimension(), Eq(3)); // Test that representative dataset config has been transferred to the // `CalibrationOptions`. @@ -285,6 +283,15 @@ TEST(ExpandPresetsTest, ExpandWeightOnlyPtqPresetDefault) { EXPECT_THAT(spec.matcher().function_name().regex(), StrEq("^.*(conv|dot_general).*")); EXPECT_TRUE(spec.method().has_weight_only_ptq()); + + const WeightOnlyPtq& weight_only_ptq_spec = spec.method().weight_only_ptq(); + + EXPECT_THAT(weight_only_ptq_spec.input_quantized_types(), + UnorderedElementsAre(Pair( + 1, Truly([](const auto& quantized_type) { + return quantized_type.has_dimension_specs() && + !quantized_type.dimension_specs().has_dimension(); + })))); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc index a06c7f8ed79fb4..3a6a30a4105d4b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc @@ -16,11 +16,7 @@ limitations under the License. 
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" namespace stablehlo::quantization { @@ -29,15 +25,6 @@ void DisableDebugging(mlir::ModuleOp module_op) { [](mlir::TF::DumpTensorOp dump_op) { dump_op.setEnabled(false); }); } -void EnableDebugging(tensorflow::quantization::ExportedModel& exported_model) { - MutateNodeDefs(*exported_model.mutable_graph_def(), - [](tensorflow::NodeDef& node_def) { - if (node_def.op() == "DumpTensor") { - (*node_def.mutable_attr())["enabled"].set_b(true); - } - }); -} - void ChangeToQuantizedFilename(mlir::ModuleOp module_op) { module_op.walk([](mlir::TF::DumpTensorOp dump_op) { dump_op.setFileName("quantized_tensor_data.pb"); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h index f034e4d94ee4bf..feae14446c8515 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h @@ -16,16 +16,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" namespace stablehlo::quantization { // Disables debugging on `DumpTensor` ops. void DisableDebugging(mlir::ModuleOp module_op); -// Enables debugging on `DumpTensor` ops. -void EnableDebugging(tensorflow::quantization::ExportedModel& exported_model); - // Changes the filename from `unquantized_tensor_data.pb` to // `quantized_tensor_data.pb`. void ChangeToQuantizedFilename(mlir::ModuleOp module_op); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc index 16a1013ae25166..94aa9ef780a522 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc @@ -15,11 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace stablehlo::quantization::io { @@ -53,4 +56,32 @@ absl::StatusOr CreateTmpDir() { return CreateTmpDir(tsl::Env::Default()); } +absl::Status WriteStringToFile(const absl::string_view file_path, + const absl::string_view data) { + auto* env = tsl::Env::Default(); + return WriteStringToFile(env, std::string(file_path), data); +} + +absl::StatusOr ReadFileToString( + const absl::string_view file_path) { + auto* env = tsl::Env::Default(); + std::string data{}; + absl::Status read_status = + ReadFileToString(env, std::string(file_path), &data); + + if (read_status.ok()) { + return data; + } else { + return read_status; + } +} + +absl::StatusOr> ListDirectory( + absl::string_view directory) { + std::vector children; + TF_RETURN_IF_ERROR( + tsl::Env::Default()->GetChildren(std::string(directory), &children)); + return children; +} + } // namespace stablehlo::quantization::io diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h index bf17ba641f9da5..39c99436e361b3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h @@ -16,9 +16,13 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ #include +#include +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" namespace stablehlo::quantization::io { @@ -41,6 +45,29 @@ absl::StatusOr CreateTmpDir(tsl::Env* env); // returned by `tsl::Env::Default`. absl::StatusOr CreateTmpDir(); +// Convenience function for writing string `data` to file without the need to +// pass `tsl::Env` instance. Internally it uses the default `tsl::Env::Default`. +absl::Status WriteStringToFile(absl::string_view file_path, + absl::string_view data); + +// Convenience function for reading string data from file at `file_path` without +// the need to pass `tsl::Env` instance. Internally it uses the default +// `tsl::Env::Default`. Returns an OK status with string data containing file +// contents. Returns non-ok status upon error, e.g. file doesn't exist. +absl::StatusOr ReadFileToString(absl::string_view file_path); + +// Lists all files and directories under the given directory. +absl::StatusOr> ListDirectory( + absl::string_view directory); + +template +absl::StatusOr ReadBinaryProto(const std::string& binary_file_path) { + MessageT message; + TF_RETURN_IF_ERROR( + tsl::ReadBinaryProto(tsl::Env::Default(), binary_file_path, &message)); + return message; +} + } // namespace stablehlo::quantization::io #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc index f4f1c5c16589e4..180df43a62a249 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include +#include #include #include @@ -23,18 +24,21 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" -#include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/types.h" namespace stablehlo::quantization::io { namespace { +using ::testing::Eq; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Not; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; using ::tsl::testing::IsOk; using ::tsl::testing::StatusIs; @@ -140,5 +144,63 @@ TEST(IoTest, CreateTmpDirWhenInvalidPathReturnsInternalError) { HasSubstr("Failed to create tmp dir"))); } +TEST(IoTest, WriteStringToFile) { + const std::string dst_file_path = + absl::StrCat(testing::TempDir(), "/tmp_file"); + + const absl::Status write_status = + WriteStringToFile(dst_file_path, "test_string"); + ASSERT_THAT(write_status, IsOk()); + + auto* const env = tsl::Env::Default(); + ASSERT_THAT(env->FileExists(dst_file_path), IsOk()); + + std::string data{}; + ASSERT_THAT(tsl::ReadFileToString(env, dst_file_path, &data), IsOk()); + + EXPECT_THAT(data, Eq("test_string")); +} + +TEST(IoTest, ReadFileToString) { + // Prepare a temp file and write some string to it. + const std::string src_file_path = + absl::StrCat(testing::TempDir(), "/tmp_file"); + + { + std::ofstream ofs(src_file_path); + ofs << "test_string"; + } + + // Test that the contents match. + const absl::StatusOr read_status = + ReadFileToString(src_file_path); + ASSERT_THAT(read_status, IsOk()); + EXPECT_THAT(*read_status, Eq("test_string")); +} + +TEST(IoTest, ListChildrenInDirectory) { + absl::StatusOr tmp_dir = CreateTmpDir(); + + ASSERT_THAT(tmp_dir, IsOk()); + + auto* const env = tsl::Env::Default(); + EXPECT_THAT(env->FileExists(*tmp_dir), IsOk()); + + ASSERT_THAT( + WriteStringToFile(absl::StrCat(*tmp_dir, "/tmp_file1"), "test_string"), + IsOk()); + ASSERT_THAT( + WriteStringToFile(absl::StrCat(*tmp_dir, "/tmp_file2"), "test_string"), + IsOk()); + ASSERT_THAT(env->RecursivelyCreateDir(absl::StrCat(*tmp_dir, "/subdir")), + IsOk()); + + absl::StatusOr> children = ListDirectory(*tmp_dir); + EXPECT_THAT(children, IsOk()); + EXPECT_THAT(children.value(), SizeIs(3)); + EXPECT_THAT(children.value(), + UnorderedElementsAre("subdir", "tmp_file1", "tmp_file2")); +} + } // namespace } // namespace stablehlo::quantization::io diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index 622ff502c01ed9..490a9290c8342b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -32,7 +32,6 @@ using ::stablehlo::quantization::CalibrationOptions; using ::stablehlo::quantization::DebuggerConfig; using ::stablehlo::quantization::PipelineConfig; using ::stablehlo::quantization::QuantizationSpecs; -using ::stablehlo::quantization::StaticRangePtqPreset; void AddPreCalibrationPasses(OpPassManager& pm, const CalibrationOptions& calibration_options, @@ -51,7 +50,6 @@ void AddPreCalibrationPasses(OpPassManager& pm, } pm.addNestedPass( CreateInsertCustomAggregationOpsPass(calibration_options)); - pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass()); } void 
AddPostCalibrationPasses(OpPassManager& pm, @@ -64,7 +62,6 @@ void AddPostCalibrationPasses(OpPassManager& pm, options.enable_per_channel_quantized_weight_ = true; // For debugging purposes. options.mlir_dump_file_name_ = "quantize_composite_functions"; - options.enable_weight_only_ = false; options.merge_fusion_with_dequantize_ = pipeline_config.merge_fusion_with_dequantize(); @@ -101,7 +98,6 @@ void AddWeightOnlyQuantizationPasses( QuantizeCompositeFunctionsPassOptions options; // For debugging purposes. options.mlir_dump_file_name_ = "quantize_composite_functions"; - options.enable_weight_only_ = true; pm.addPass(createQuantizeCompositeFunctionsPass(options)); // Add an inliner pass to inline quantized StableHLO functions. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc index 001ece707cfe90..d164a8e07617e9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc @@ -14,12 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h" +#include + #include "absl/base/nullability.h" #include "absl/log/die_if_null.h" #include "absl/status/statusor.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -27,6 +31,7 @@ limitations under the License. namespace mlir::quant::stablehlo { +using ::stablehlo::quantization::GetReportFilePath; using ::stablehlo::quantization::PipelineConfig; using ::stablehlo::quantization::QuantizationConfig; using ::stablehlo::quantization::QuantizationSpecs; @@ -41,6 +46,11 @@ absl::StatusOr PostCalibrationComponent::Run( TF_RETURN_IF_ERROR(RunPasses( kName, /*add_passes_func=*/ [&config](PassManager& pm) { + // Add instrumentation to save quantization report after quantization. + pm.addInstrumentation( + std::make_unique( + GetReportFilePath(config))); + AddPostCalibrationPasses(pm, config.pipeline_config(), config.specs()); }, *ctx_, module_op)); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc index 93be3516d76f8d..f8181deca51a0e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -26,6 +28,7 @@ limitations under the License. 
#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep @@ -33,8 +36,10 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { +using ::stablehlo::quantization::Method; using ::stablehlo::quantization::QuantizationResult; using ::stablehlo::quantization::QuantizationResults; +using ::stablehlo::quantization::io::WriteStringToFile; using ::tsl::protobuf::TextFormat; // Given a `quantized_func_name` that starts with `kQuantizedFuncPrefix`, @@ -48,19 +53,27 @@ std::string GetCompositeFunctionName(const StringRef quantized_func_name) { // Retrieves `QuantizationResult` from `call_op`. If the callee's name starts // with `kQuantizedFuncPrefix` then a `QuantizationResult` will be returned with // its `name` field set to the callee's name reverted back to the lifted -// function's name. Otherwise, returns `std::nullopt`. +// function's name. Also, `call_op` must have the `kQuantizationMethodAttr` +// attribute, which is deserialized as `Method` and set in the returned +// `QuantizationResult`. Otherwise, it returns `std::nullopt`. std::optional GetQuantizationResult(func::CallOp call_op) { const StringRef callee_name = call_op.getCalleeAttr().getValue(); + if (!callee_name.starts_with(kQuantizedFuncPrefix)) { + return std::nullopt; // `call_op` is not a quantized function call. + } - if (callee_name.starts_with(kQuantizedFuncPrefix)) { - // TODO: b/329554870 - Transfer the `Method` used to quantize the op. - QuantizationResult result{}; - result.mutable_quantizable_unit()->set_name( - GetCompositeFunctionName(callee_name)); - return result; - } else { + absl::StatusOr method = GetQuantizationMethod(call_op); + if (!method.ok()) { + call_op->emitError() << "Failed to get quantization method: " + << method.status().ToString(); return std::nullopt; } + + QuantizationResult result{}; + result.mutable_quantizable_unit()->set_name( + GetCompositeFunctionName(callee_name)); + *result.mutable_method() = std::move(*method); + return result; } // Retrieves `QuantizationResult` from `xla_call_module_op`. If @@ -72,9 +85,8 @@ std::optional GetQuantizationResult(func::CallOp call_op) { std::optional GetQuantizationResult( TF::XlaCallModuleOp xla_call_module_op) { const StringAttr callee_name_attr = - xla_call_module_op - ->getDiscardableAttr(kOriginalStablehloEntryFunctionAttrName) - .dyn_cast_or_null(); + mlir::dyn_cast_or_null(xla_call_module_op->getDiscardableAttr( + kOriginalStablehloEntryFunctionAttrName)); // `TF::XlaCallModuleOp` without the `_original_entry_function` means it is // not a quantizable unit. @@ -152,4 +164,11 @@ void QuantizationReport::Print() const { llvm::outs().flush(); // Show the report immediately. 
} +absl::Status QuantizationReport::Save(const StringRef file_path) const { + std::string results_str{}; + TextFormat::PrintToString(GetQuantizationResults(), &results_str); + + return WriteStringToFile(file_path, results_str); +} + } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h index a362bb758cb60c..8252dda620dc3e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h @@ -17,7 +17,9 @@ limitations under the License. #include +#include "absl/status/status.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" namespace mlir::quant::stablehlo { @@ -50,6 +52,11 @@ class QuantizationReport { // Prints a human-readable report to stdout. void Print() const; + // Saves the report to `file_path`. The textproto representation of + // `QuantizationResults` will be written to the file. Returns non-ok status + // when the file write fails. + absl::Status Save(StringRef file_path) const; + private: ::stablehlo::quantization::QuantizationResults CollectResultsFromModuleOp( ModuleOp module_op) const; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc index 4783fb6beebc2d..690ee47e5b3c7d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc @@ -19,12 +19,17 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" namespace mlir::quant::stablehlo { namespace { @@ -33,10 +38,14 @@ using ::stablehlo::quantization::Method; using ::stablehlo::quantization::QuantizableUnit; using ::stablehlo::quantization::QuantizationResult; using ::stablehlo::quantization::QuantizationResults; +using ::stablehlo::quantization::io::ReadFileToString; +using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::SizeIs; using ::testing::StrEq; +using ::testing::TempDir; using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOk; using QuantizationReportTest = ::mlir::quant::QuantizationTestBase; @@ -74,7 +83,7 @@ TEST_F(QuantizationReportTest, InitializeWithModuleOp) { func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { %0 = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>> %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> - %2 = call @quantized_dot_general_fn(%1, %0) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %2 = call @quantized_dot_general_fn(%1, %0) {_quantization_method = "static_range_ptq { }"} : 
(tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> return %3 : tensor<1x3xf32> } @@ -96,11 +105,73 @@ TEST_F(QuantizationReportTest, InitializeWithModuleOp) { // Test that the quantized `QuantizableUnit` corresponding to // `composite_dot_general_fn` is captured. - // TODO: Transfer the `Method` used to quantize the op. const QuantizationResult& result = results.results(0); EXPECT_THAT(result.quantizable_unit().name(), StrEq("composite_dot_general_fn")); - EXPECT_FALSE(result.has_method()); + EXPECT_TRUE(result.method().has_static_range_ptq()); +} + +TEST_F(QuantizationReportTest, + InitializeWithModuleOpWithoutQuantizationMethodAttribute) { + // A quantized dot_general op but the `CallOp` is missing the + // `_quantization_method` attribute. + constexpr absl::string_view + kQuantizedDotGeneralMissingQuantizationMethodAttr = R"mlir( + func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %2 = call @quantized_dot_general_fn(%1, %0) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + func.func private @quantized_dot_general_fn(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kQuantizedDotGeneralMissingQuantizationMethodAttr); + ASSERT_TRUE(module_op); + + const QuantizationReport report(*module_op); + const QuantizationResults& results = report.GetQuantizationResults(); + // The quantized call op without the _quantization_method attribute is not + // captured as a `QuantizationResult`. + ASSERT_THAT(results.results(), IsEmpty()); +} + +TEST_F(QuantizationReportTest, InitializeWithModuleOpWithInvalidCalleeName) { + // A quantized dot_general op but the callee function has an invalid name. It + // is expected to start with `quantized_`. 
+ constexpr absl::string_view kQuantizedDotGeneralWithInvalidCalleeName = + R"mlir( + func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %2 = call @invalid_quantized_dot_general_fn(%1, %0) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + func.func private @invalid_quantized_dot_general_fn(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kQuantizedDotGeneralWithInvalidCalleeName); + ASSERT_TRUE(module_op); + + const QuantizationReport report(*module_op); + const QuantizationResults& results = report.GetQuantizationResults(); + // The quantized call op whose callee doesn't start with `quantized_` is not + // captured as a `QuantizationResult`. + ASSERT_THAT(results.results(), IsEmpty()); } TEST_F(QuantizationReportTest, InitializeWithModuleOpWithNonQuantizedOp) { @@ -141,11 +212,11 @@ TEST_F(QuantizationReportTest, func.func @main(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x3xf32> { // Non-quantized dot_general. %0 = stablehlo.constant dense<3.000000e+0> : tensor<2x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> // Quantized dot_general. 
%2 = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>> %3 = stablehlo.uniform_quantize %arg1 : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> - %4 = call @quantized_dot_general_fn_2(%3, %2) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %4 = call @quantized_dot_general_fn_2(%3, %2) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> %5 = stablehlo.uniform_dequantize %4 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> // Add is there to prevent from dot_generals from being DCEed. %6 = stablehlo.add %1, %5 : tensor<1x3xf32> @@ -178,7 +249,7 @@ TEST_F(QuantizationReportTest, const QuantizationResult& quantized_result = results.results(0); EXPECT_THAT(quantized_result.quantizable_unit().name(), StrEq("composite_dot_general_fn_2")); - EXPECT_FALSE(quantized_result.has_method()); + EXPECT_TRUE(quantized_result.method().has_static_range_ptq()); // Test that the non-quantized op is captured in `results`. const QuantizationResult& non_quantized_result = results.results(1); @@ -203,9 +274,52 @@ TEST_F(QuantizationReportTest, ToString) { std::string result_str{}; TextFormat::PrintToString(report.GetQuantizationResults(), &result_str); - EXPECT_THAT(report.ToString(), testing::HasSubstr("Quantization Report")); - EXPECT_THAT(report.ToString(), testing::HasSubstr(result_str)); - EXPECT_THAT(report.ToString(), testing::HasSubstr("Quantization Report End")); + EXPECT_THAT(report.ToString(), HasSubstr("Quantization Report")); + EXPECT_THAT(report.ToString(), HasSubstr(result_str)); + EXPECT_THAT(report.ToString(), HasSubstr("Quantization Report End")); +} + +TEST_F(QuantizationReportTest, Save) { + constexpr absl::string_view kQuantizedDotGeneral = R"mlir( + func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %2 = call @quantized_dot_general_fn(%1, %0) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + func.func private @quantized_dot_general_fn(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kQuantizedDotGeneral); + ASSERT_TRUE(module_op); + + const QuantizationReport report(*module_op); + + const std::string dst_file_path = + absl::StrCat(TempDir(), "/quantization_report.txtpb"); + const absl::Status save_status = report.Save(dst_file_path); + ASSERT_THAT(save_status, IsOk()); + + const 
absl::StatusOr<std::string> file_data = ReadFileToString(dst_file_path);
+  ASSERT_THAT(file_data, IsOk());
+
+  // Test that the file data can be parsed as `QuantizationResults`.
+  QuantizationResults results{};
+  ASSERT_TRUE(TextFormat::ParseFromString(*file_data, &results));
+
+  // Check that `results` reflects the information of the quantized units
+  // properly.
+  ASSERT_THAT(results.results(), SizeIs(1));
+  EXPECT_THAT(results.results(0).quantizable_unit().name(),
+              StrEq("composite_dot_general_fn"));
+  EXPECT_TRUE(results.results(0).method().has_static_range_ptq());
+}
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc
index a223a0b03f58a4..295ab06eb1bf70 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc
@@ -77,7 +77,7 @@ absl::StatusOr<ImportedMlirModuleOp> SavedModelToMlirModuleOp(
         module_op.status().ToString()));
   }
 
-  return std::make_pair(module_op->release(), std::move(bundle));
+  return std::make_pair(std::move(*module_op), std::move(bundle));
 }
 
 absl::StatusOr<absl::flat_hash_map<FunctionName, FunctionAlias>>
@@ -119,7 +119,7 @@ void UpdateFunctionAliases(
       });
 }
 
-absl::StatusOr<ModuleOp> ImportSavedModel(
+absl::StatusOr<OwningOpRef<ModuleOp>> ImportSavedModel(
     const absl::string_view saved_model_path,
     const std::vector<std::string>& signature_keys,
    const std::unordered_set<std::string>& tags,
@@ -132,7 +132,7 @@ absl::StatusOr<OwningOpRef<ModuleOp>> ImportSavedModel(
       SavedModelToMlirModuleOp(saved_model_path, tags, signature_keys, ctx));
   auto [module_op, saved_model_bundle] = std::move(imported_module);
 
-  UpdateFunctionAliases(function_aliases, module_op);
+  UpdateFunctionAliases(function_aliases, *module_op);
 
   // Collect the names of the functions that have aliases so that they may not
   // be inlined.
@@ -143,11 +143,11 @@ absl::StatusOr<OwningOpRef<ModuleOp>> ImportSavedModel(
   TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph(
       mlir_dump_file_prefix, /*is_inliner_run=*/true,
-      /*noinline_functions=*/aliased_function_names, module_op, &ctx,
+      /*noinline_functions=*/aliased_function_names, *module_op, &ctx,
       saved_model_bundle == nullptr ? nullptr
                                     : saved_model_bundle->GetSession(),
       /*run_tf_to_stablehlo=*/true, /*deserialize_xla_call_module=*/false));
-  return module_op;
+  return std::move(module_op);
 }
 
 }  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h
index 631d2e714900aa..8f1e4236e09823 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h
@@ -27,6 +27,7 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "tensorflow/cc/saved_model/loader.h"
 #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h"
 #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
@@ -38,7 +39,8 @@ namespace mlir::quant::stablehlo {
 // `tensorflow::Session` which may be useful when reading values from resources
 // (e.g. `TF::VarHandleOp`s).
 using ImportedMlirModuleOp =
-    std::pair<ModuleOp, std::unique_ptr<::tensorflow::SavedModelBundle>>;
+    std::pair<OwningOpRef<ModuleOp>,
+              std::unique_ptr<::tensorflow::SavedModelBundle>>;
 
 // Loads a SavedModel at `saved_model_path` and converts it to `mlir::ModuleOp`.
//
@@ -72,7 +74,7 @@ void UpdateFunctionAliases(
 // Loads a SavedModel to `mlir::ModuleOp` and performs preprocessing including
 // shape inference and graph freezing.
 // TODO: b/329206105 - Add unit tests after decomposing preprocessing passes.
-absl::StatusOr<ModuleOp> ImportSavedModel(
+absl::StatusOr<OwningOpRef<ModuleOp>> ImportSavedModel(
     absl::string_view saved_model_path,
     const std::vector<std::string>& signature_keys,
     const std::unordered_set<std::string>& tags,
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc
index 015ab7605a05b7..3d350613629c7e 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h"
 #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h"
@@ -105,7 +106,7 @@ absl::Status QuantizeStaticRangePtq(
   }
 
   TF_ASSIGN_OR_RETURN(
-      ModuleOp module_op,
+      OwningOpRef<ModuleOp> module,
       ImportSavedModel(src_saved_model_path, signature_keys, tags,
                        quantization_config, PreCalibrationComponent::kName,
                        *function_aliases, *ctx));
 
   StaticRangePtqComponent static_range_ptq_component(
       ctx.get(), &py_function_library, src_saved_model_path, signature_keys,
       tags, signature_def_map, *function_aliases);
-  TF_ASSIGN_OR_RETURN(module_op, static_range_ptq_component.Run(
-                                     module_op, quantization_config));
+  TF_ASSIGN_OR_RETURN(
+      *module, static_range_ptq_component.Run(*module, quantization_config));
 
   TF_ASSIGN_OR_RETURN(
       const ExportedModel post_calibrated_exported_model,
       CreateExportedModel(signature_keys, tags, quantization_config,
                           PostCalibrationComponent::kName, *function_aliases,
-                          *ctx, module_op));
+                          *ctx, *module));
 
   // Remove the `tpu` tag for exporting because the output quantized model is
   // essentially a CPU model.
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc
index bbd9a9c25620bd..f1df09c36ccce0 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc
@@ -28,11 +28,13 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" @@ -42,6 +44,7 @@ limitations under the License. namespace mlir::quant::stablehlo { +using ::stablehlo::quantization::GetReportFilePath; using ::stablehlo::quantization::QuantizationConfig; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::ExportedModel; @@ -56,6 +59,11 @@ absl::StatusOr WeightOnlyPtqComponent::Run( TF_RETURN_IF_ERROR(RunPasses( kName, /*add_passes_func=*/ [&config](PassManager& pm) { + // Add instrumentation to save quantization report after quantization. + pm.addInstrumentation( + std::make_unique( + GetReportFilePath(config))); + AddWeightOnlyQuantizationPasses(pm, config.specs(), config.pipeline_config(), config.debugger_config()); @@ -85,20 +93,20 @@ absl::Status QuantizeWeightOnlyPtq( } TF_ASSIGN_OR_RETURN( - ModuleOp module_op, + auto module, ImportSavedModel(src_saved_model_path, signature_keys, tags, quantization_config, WeightOnlyPtqComponent::kName, *function_aliases, *ctx)); WeightOnlyPtqComponent weight_only_ptq_component(ctx.get()); TF_ASSIGN_OR_RETURN( - module_op, weight_only_ptq_component.Run(module_op, quantization_config)); + *module, weight_only_ptq_component.Run(*module, quantization_config)); TF_ASSIGN_OR_RETURN( const ExportedModel post_calibrated_exported_model, CreateExportedModel(signature_keys, tags, quantization_config, WeightOnlyPtqComponent::kName, *function_aliases, - *ctx, module_op)); + *ctx, *module)); // Remove the `tpu` tag for exporting because the output quantized model is // essentially a CPU model. 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/BUILD
new file mode 100644
index 00000000000000..476192965a3f8f
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/BUILD
@@ -0,0 +1,46 @@
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["@stablehlo//:license"],
+    default_visibility = [
+        "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__",
+    ],
+    licenses = ["notice"],
+)
+
+cc_library(
+    name = "save_report",
+    srcs = ["save_report.cc"],
+    hdrs = ["save_report.h"],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        "//tensorflow/compiler/mlir/quantization/stablehlo/cc:report",
+        "@com_google_absl//absl/base:nullability",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings:string_view",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+tf_cc_test(
+    name = "save_report_test",
+    srcs = ["save_report_test.cc"],
+    deps = [
+        ":save_report",
+        "//tensorflow/compiler/mlir/quantization/common:test_base",
+        "//tensorflow/compiler/mlir/quantization/stablehlo:passes",
+        "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest_main",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@local_tsl//tsl/platform:protobuf",
+        "@local_tsl//tsl/platform:status_matchers",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc
new file mode 100644
index 00000000000000..e1a705cdbb24f6
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc
@@ -0,0 +1,95 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h"
+
+#include <optional>
+#include <string>
+
+#include "absl/base/nullability.h"
+#include "absl/log/log.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h"
+
+namespace mlir::quant::stablehlo {
+namespace {
+
+// Converts `std::optional<absl::string_view>` to `std::optional<std::string>`.
+// A `std::nullopt` is returned when `view` is `std::nullopt`.
+std::optional<std::string> OptionalStringViewToOptionalString(
+    std::optional<absl::string_view> view) {
+  if (view == std::nullopt) return std::nullopt;
+  return std::make_optional<std::string>(*view);
+}
+
+// Whether the pass is `QuantizeCompositeFunctionPass`.
+bool IsQuantizeCompositeFunctionPass(absl::Nullable<Pass*> pass,
+                                     absl::Nullable<Operation*> op) {
+  // It is known that `op` is `ModuleOp` when `pass` is
+  // `QuantizeCompositeFunctionPass`, but the check is still performed to be
+  // defensive.
+  return pass != nullptr &&
+         pass->getArgument() == "stablehlo-quantize-composite-functions" &&
+         isa_and_nonnull<ModuleOp>(op);
+}
+
+// The report is saved only when:
+// * `QuantizeCompositeFunctionPass` has just been run,
+// * the pass ran on a `ModuleOp`, and
+// * `file_path` is not `std::nullopt`.
+bool ShouldSaveReport(absl::Nullable<Pass*> pass,
+                      absl::Nullable<Operation*> op,
+                      const std::optional<std::string>& file_path) {
+  return file_path != std::nullopt &&
+         IsQuantizeCompositeFunctionPass(pass, op);
+}
+
+void SaveReport(const QuantizationReport& report,
+                const absl::string_view file_path) {
+  if (const absl::Status save_status = report.Save(file_path);
+      save_status.ok()) {
+    LOG(INFO) << "Successfully saved quantization report to: " << file_path;
+  } else {
+    LOG(ERROR) << "Failed to save quantization report to: " << file_path
+               << " with status: " << save_status;
+  }
+}
+
+}  // namespace
+
+SaveQuantizationReportInstrumentation::SaveQuantizationReportInstrumentation(
+    std::optional<absl::string_view> file_path)
+    : file_path_(OptionalStringViewToOptionalString(file_path)) {}
+
+void SaveQuantizationReportInstrumentation::runAfterPass(Pass* pass,
+                                                         Operation* op) {
+  // Only run after `QuantizeCompositeFunctionPass`.
+  if (!IsQuantizeCompositeFunctionPass(pass, op)) return;
+
+  auto module_op = cast<ModuleOp>(op);
+  const QuantizationReport report(module_op);
+
+  // Print a human-readable report to stdout regardless of whether the report
+  // is saved to file.
+  report.Print();
+
+  // Exit early if the report should not be saved to file.
+  if (!ShouldSaveReport(pass, op, file_path_)) return;
+
+  SaveReport(report, *file_path_);
+}
+
+}  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h
new file mode 100644
index 00000000000000..e690e6252b3393
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h
@@ -0,0 +1,52 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_INSTRUMENTATIONS_SAVE_REPORT_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_INSTRUMENTATIONS_SAVE_REPORT_H_
+
+#include <optional>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassInstrumentation.h"  // from @llvm-project
+
+namespace mlir::quant::stablehlo {
+
+// A `PassInstrumentation` that saves the quantization report to a file after
+// `QuantizeCompositeFunctionsPass` is run. It inspects the `ModuleOp` after
+// quantization and analyzes the quantizable units and the quantization methods
+// used. The report file is saved at `file_path` and contains a textproto of
+// `QuantizationResults`. `file_path`'s base directories should already exist
+// (this pass instrumentation will not `mkdir` them).
+//
+// See `QuantizationReport` for further details on the quantization report.
+class SaveQuantizationReportInstrumentation : public PassInstrumentation {
+ public:
+  // `file_path` is the path to save the report file. The report file is in
+  // textproto format, so a `.txtpb` extension is preferred, but using another
+  // extension does not cause an error. This instrumentation will not be run
+  // if `file_path` is `std::nullopt`.
+  explicit SaveQuantizationReportInstrumentation(
+      std::optional<absl::string_view> file_path);
+
+  void runAfterPass(Pass* pass, Operation* op) override;
+
+ private:
+  std::optional<std::string> file_path_;  // Path to the file to save the report.
+};
+
+}  // namespace mlir::quant::stablehlo
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_INSTRUMENTATIONS_SAVE_REPORT_H_
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report_test.cc
new file mode 100644
index 00000000000000..27d282dc309a7e
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report_test.cc
@@ -0,0 +1,186 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h"
+
+#include <memory>
+#include <optional>
+#include <string>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/common/test_base.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h"
+#include "tsl/platform/protobuf.h"  // IWYU pragma: keep
+#include "tsl/platform/status_matchers.h"
+
+namespace mlir::quant::stablehlo {
+namespace {
+
+using ::stablehlo::quantization::QuantizationResults;
+using ::stablehlo::quantization::io::ReadFileToString;
+using ::testing::SizeIs;
+using ::testing::StrEq;
+using ::tsl::protobuf::TextFormat;
+using ::tsl::testing::IsOk;
+using ::tsl::testing::StatusIs;
+
+using SaveQuantizationReportInstrumentationTest = QuantizationTestBase;
+
+TEST_F(SaveQuantizationReportInstrumentationTest, SaveReport) {
+  constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir(
+    func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
+      %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+      %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
+      %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+      %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+      return %2 : tensor<1x3xf32>
+    }
+
+    func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
+      %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+      return %0 : tensor<1x3xf32>
+    }
+  )mlir";
+
+  const OwningOpRef<ModuleOp> module_op =
+      ParseModuleOpString(kModuleWithCompositeDotGeneral);
+  ASSERT_TRUE(module_op);
+
+  // Create a pass manager with `SaveQuantizationReportInstrumentation` and
+  // `QuantizeCompositeFunctionsPass`. Run the passes against `module_op`.
+  PassManager pm(ctx_.get());
+
+  QuantizeCompositeFunctionsPassOptions options;
+  pm.addPass(createQuantizeCompositeFunctionsPass(options));
+
+  const std::string report_file_path =
+      absl::StrCat(testing::TempDir(), "/save_report.txtpb");
+  pm.addInstrumentation(
+      std::make_unique<SaveQuantizationReportInstrumentation>(
+          report_file_path));
+
+  const LogicalResult run_result = pm.run(*module_op);
+  ASSERT_TRUE(succeeded(run_result));
+
+  // Check that the report file contains a `QuantizationResults` textproto
+  // reflecting the quantization results; in this case,
+  // `composite_dot_general_fn` quantized with the `static_range_ptq` method.
+  const absl::StatusOr<std::string> file_data =
+      ReadFileToString(report_file_path);
+  ASSERT_THAT(file_data, IsOk());
+
+  /*
+  results {
+    quantizable_unit {
+      name: "composite_dot_general_fn"
+    }
+    method { static_range_ptq { } }
+  }
+  */
+  QuantizationResults results{};
+  ASSERT_TRUE(TextFormat::ParseFromString(*file_data, &results));
+  ASSERT_THAT(results.results(), SizeIs(1));
+  EXPECT_THAT(results.results(0).quantizable_unit().name(),
+              StrEq("composite_dot_general_fn"));
+  EXPECT_TRUE(results.results(0).method().has_static_range_ptq());
+}
+
+TEST_F(SaveQuantizationReportInstrumentationTest,
+       ReportNotSavedWhenNoQuantizeCompositeFunctionsPass) {
+  constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir(
+    func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
+      %cst = "stablehlo.constant"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+      %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
+      %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+      %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+      return %2 : tensor<1x3xf32>
+    }
+
+    func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
+      %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+      return %0 : tensor<1x3xf32>
+    }
+  )mlir";
+
+  const OwningOpRef<ModuleOp> module_op =
+      ParseModuleOpString(kModuleWithCompositeDotGeneral);
+  ASSERT_TRUE(module_op);
+
+  // Create a pass manager with `SaveQuantizationReportInstrumentation` and a
+  // pass that is not `QuantizeCompositeFunctionsPass`. Run the passes against
+  // `module_op`.
+  PassManager pm(ctx_.get());
+
+  pm.addPass(createPrepareQuantizePass());
+
+  const std::string report_file_path = absl::StrCat(
+      testing::TempDir(),
+      "/report_not_saved_no_quantize_composite_functions_pass.txtpb");
+  pm.addInstrumentation(
+      std::make_unique<SaveQuantizationReportInstrumentation>(
+          report_file_path));
+
+  const LogicalResult run_result = pm.run(*module_op);
+  ASSERT_TRUE(succeeded(run_result));
+
+  // The report file is not created because `QuantizeCompositeFunctionsPass`
+  // was not run.
+  EXPECT_THAT(ReadFileToString(report_file_path),
+              StatusIs(absl::StatusCode::kNotFound));
+}
+
+TEST_F(SaveQuantizationReportInstrumentationTest,
+       ReportNotSavedWhenReportFilePathIsNullopt) {
+  constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir(
+    func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
+      %cst = "stablehlo.constant"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+      %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
+      %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+      %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+      return %2 : tensor<1x3xf32>
+    }
+
+    func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
+      %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+      return %0 : tensor<1x3xf32>
+    }
+  )mlir";
+
+  const OwningOpRef<ModuleOp> module_op =
+      ParseModuleOpString(kModuleWithCompositeDotGeneral);
+  ASSERT_TRUE(module_op);
+
+  PassManager pm(ctx_.get());
+
+  QuantizeCompositeFunctionsPassOptions options;
+  pm.addPass(createQuantizeCompositeFunctionsPass(options));
+  pm.addInstrumentation(
+      std::make_unique<SaveQuantizationReportInstrumentation>(
+          /*file_path=*/std::nullopt));
+
+  // The report file is not created and `SaveQuantizationReportInstrumentation`
+  // is not run, but the passes still run without errors.
+  const LogicalResult run_result = pm.run(*module_op);
+  ASSERT_TRUE(succeeded(run_result));
+}
+
+}  // namespace
+}  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc
index 3018db7b2649e9..54b0744fcda94d 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc
@@ -118,6 +118,8 @@ std::unique_ptr<OpQuantSpec> GetStableHloOpQuantSpec(Operation* op) {
       if (auto optional_dim = GetDotGeneralQuantizationDim(dot_general_op);
           optional_dim) {
         spec->coeff_op_quant_dim[1] = optional_dim.value();
+      } else {
+        spec->coeff_op_quant_dim[1] = -1;
       }
       if (function_name.contains("with_bias")) {
         spec->biases_params[2] = {{0, 1},
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc
index cd861d934e75f8..5575a7516fccc9 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc
@@ -88,7 +88,7 @@ FailureOr<TensorType> GetUniformQuantizedType(
   }
 
   auto original_element_type = getElementTypeOrSelf(original_type);
-  if (!original_element_type.isa<TF::Qint8Type, TF::Qint32Type>()) {
+  if (!mlir::isa<TF::Qint8Type, TF::Qint32Type>(original_element_type)) {
     return rewriter.notifyMatchFailure(
         op, "Quantized type must be qint8 or qint32.");
   }
@@ -112,7 +112,7 @@ FailureOr<TensorType> GetUniformQuantizedType(
         quantized_dimension, storage_type_min, storage_type_max);
   }
 
-  return original_type.cast<TensorType>().clone(elem_ty);
+  return mlir::cast<TensorType>(original_type).clone(elem_ty);
 }
 
 // If operand is TF const op, create MHLO constant op from the contents.
@@ -178,8 +178,8 @@ FailureOr<DenseIntElementsAttr> ConvertPaddingAttr(
     const xla::ConvolutionDimensionNumbers &dnums, PatternRewriter &rewriter) {
   StringAttr conv_padding = op.getPaddingAttr();
   SmallVector<int64_t> padding_nums;
-  ShapedType lhs_shape = op.getLhs().getType().template cast<ShapedType>();
-  ShapedType rhs_shape = op.getRhs().getType().template cast<ShapedType>();
+  ShapedType lhs_shape = mlir::cast<ShapedType>(op.getLhs().getType());
+  ShapedType rhs_shape = mlir::cast<ShapedType>(op.getRhs().getType());
 
   // Handle only static shape cases.
   // TODO(b/260284866): Handle dynamic shape cases.
@@ -192,26 +192,26 @@ FailureOr<DenseIntElementsAttr> ConvertPaddingAttr(
   const int64_t padding_nums_size = 2 * (rhs_shape.getRank() - 2);
   padding_nums.reserve(padding_nums_size);
-  if (conv_padding.strref().equals("EXPLICIT")) {
+  if (conv_padding.strref() == "EXPLICIT") {
     for (auto padding_elem :
          op.getExplicitPaddingAttr().template getAsRange<IntegerAttr>()) {
       padding_nums.push_back(padding_elem.getInt());
     }
-  } else if (conv_padding.strref().equals("VALID")) {
+  } else if (conv_padding.strref() == "VALID") {
     padding_nums.resize(padding_nums_size, 0);
   } else {
     padding_nums.resize(padding_nums_size);
     for (int i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
       const int64_t stride =
-          op.getWindowStridesAttr()[i].template cast<IntegerAttr>().getInt();
+          mlir::cast<IntegerAttr>(op.getWindowStridesAttr()[i]).getInt();
       const int64_t lhs_size_dilated =
           ::tensorflow::UniformQuantizedConvolutionParams::DilatedSize(
               lhs_shape.getDimSize(dnums.input_spatial_dimensions(i)),
-              op.getLhsDilationAttr()[i].template cast<IntegerAttr>().getInt());
+              mlir::cast<IntegerAttr>(op.getLhsDilationAttr()[i]).getInt());
       const int64_t rhs_size_dilated =
           ::tensorflow::UniformQuantizedConvolutionParams::DilatedSize(
               rhs_shape.getDimSize(dnums.kernel_spatial_dimensions(i)),
-              op.getRhsDilationAttr()[i].template cast<IntegerAttr>().getInt());
+              mlir::cast<IntegerAttr>(op.getRhsDilationAttr()[i]).getInt());
 
       const int64_t output_size = (lhs_size_dilated + stride - 1) / stride;
       const int64_t total_padding = std::max(
@@ -262,7 +262,7 @@ FailureOr<SmallVector<NamedAttribute>> ConvertToMhloConvolutionOpAttrs(
         attr.getName() == op.getLhsDilationAttrName() ||
         attr.getName() == op.getRhsDilationAttrName()) {
       attr.setValue(ConvertToDenseElementsAttr(
-          attr.getValue().template cast<ArrayAttr>(), rewriter));
+          mlir::cast<ArrayAttr>(attr.getValue()), rewriter));
       converted_attrs.push_back(attr);
     }
   }
@@ -362,9 +362,9 @@ class ConvertUniformQuantizeOp
         op->getLoc(), *output_type, op.getInput());
     rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
         op,
-        output_type->clone(output_type->getElementType()
-                               .dyn_cast<quant::UniformQuantizedType>()
-                               .getStorageType()),
+        output_type->clone(
+            mlir::dyn_cast<quant::UniformQuantizedType>(
+                output_type->getElementType())
+                .getStorageType()),
         result);
 
     return success();
@@ -438,9 +438,9 @@ class ConvertUniformRequantizeOp
         op->getLoc(), *output_type, input_quant);
     rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
         op,
-        output_type->clone(output_type->getElementType()
-                               .dyn_cast<quant::UniformQuantizedType>()
-                               .getStorageType()),
+        output_type->clone(
+            mlir::dyn_cast<quant::UniformQuantizedType>(
+                output_type->getElementType())
+                .getStorageType()),
         result);
     return success();
   }
@@ -502,9 +502,9 @@ class ConvertUniformQuantizedDotOp
         /*precision_config=*/nullptr);
     rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
         op,
-        output_type->clone(output_type->getElementType()
-                               .dyn_cast<quant::UniformQuantizedType>()
-                               .getStorageType()),
+        output_type->clone(
+            mlir::dyn_cast<quant::UniformQuantizedType>(
+                output_type->getElementType())
+                .getStorageType()),
         result);
     return success();
   }
@@ -564,9 +564,9 @@ class ConvertUniformQuantizedConvolutionOp
         op->getLoc(), *output_type, operands, *converted_attrs_or);
     rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
         op,
-        output_type->clone(output_type->getElementType()
-                               .dyn_cast<quant::UniformQuantizedType>()
-                               .getStorageType()),
+        output_type->clone(
+            mlir::dyn_cast<quant::UniformQuantizedType>(
+                output_type->getElementType())
+                .getStorageType()),
         result);
     return success();
   }
@@ -582,7 +582,7 @@ class ConvertUniformQuantizedAddOp
                   ConversionPatternRewriter &rewriter) const override {
     Value lhs = adaptor.getLhs();
-    auto lhs_type = lhs.getType().cast<ShapedType>();
+    auto lhs_type = mlir::cast<ShapedType>(lhs.getType());
     if (!lhs_type.hasRank()) {
       return rewriter.notifyMatchFailure(
           op, "Legalization supports cases where only lhs rank known.");
@@ -632,9 +632,9 @@ class ConvertUniformQuantizedAddOp
         op->getLoc(), *output_type, lhs, *rhs_or,
         broadcast_dims);
     rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
         op,
-        output_type->clone(output_type->getElementType()
-                               .dyn_cast<quant::UniformQuantizedType>()
-                               .getStorageType()),
+        output_type->clone(
+            mlir::dyn_cast<quant::UniformQuantizedType>(
+                output_type->getElementType())
+                .getStorageType()),
         result);
     return success();
   }
@@ -692,9 +692,9 @@ class ConvertUniformQuantizedClipByValueOp
         op->getLoc(), *output_type, res_min_clipped, *max_or, broadcast_dims);
     rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
         op,
-        output_type->clone(output_type->getElementType()
-                               .dyn_cast<quant::UniformQuantizedType>()
-                               .getStorageType()),
+        output_type->clone(
+            mlir::dyn_cast<quant::UniformQuantizedType>(
+                output_type->getElementType())
+                .getStorageType()),
         res_max_clipped);
     return success();
   }
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc
index 65192fc1117673..f07097a109a0af 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc
@@ -71,7 +71,7 @@ bool IsIllegalType(Type type) {
 // If input is not TF qint types, returns the original type.
 Type ToLegalType(Type type) {
   if (IsTFQintType(type)) return GetIntTypeFromTFQint(type);
-  if (auto shaped = type.dyn_cast<ShapedType>()) {
+  if (auto shaped = mlir::dyn_cast<ShapedType>(type)) {
     Type elem = shaped.getElementType();
     if (IsTFQintType(elem)) return shaped.clone(ToLegalType(elem));
   }
@@ -289,7 +289,7 @@ class TFConstOpQuantToIntPattern : public OpConversionPattern<TF::ConstOp> {
     }
     auto dense_attr_or = GetDenseAttrFromTensorProtoAttr(
         tensor_proto_attr.getValue(),
-        ToLegalType(op.getOutput().getType()).dyn_cast<TensorType>());
+        mlir::dyn_cast<TensorType>(ToLegalType(op.getOutput().getType())));
     if (failed(dense_attr_or)) {
       op->emitError("failed to get DenseElementAttr.");
       return failure();
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc
index 2825195addea12..7484ed89aa51b1 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -53,7 +54,7 @@ class VerifyQuantLegalization bool IsQuantType(Type type) { auto element_type = getElementTypeOrSelf(type); - return element_type.isa() || + return mlir::isa(element_type) || IsTFQintType(element_type); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc index 0204a19452bb0d..4a85786dc94937 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc @@ -143,8 +143,8 @@ class BFloat16TypePattern : public ConversionPattern { state.attributes.set( const_op.getValueAttrName(), DenseFPElementsAttr::get( - const_op.getValue().getType().dyn_cast().clone( - rewriter.getBF16Type()), + mlir::dyn_cast(const_op.getValue().getType()) + .clone(rewriter.getBF16Type()), bfloat16_values)); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc index 5be09ce2ad47ef..686204030c1fdc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc @@ -155,7 +155,7 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp PatternRewriter& rewriter) const override { auto transpose_op = cast(op.getOperand(0).getDefiningOp()); - const auto result_type = op.getResult(0).getType().cast(); + const auto result_type = mlir::cast(op.getResult(0).getType()); const SmallVector new_result_shape = Permute(result_type.getShape(), kNchwToNhwcPermutation); @@ -169,16 +169,16 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp op.getLoc(), new_result_type, transpose_op.getOperand(), /*init_value=*/op.getOperand(1), /*window_dimensions=*/ - PermuteI64ArrayAttr(rewriter, op.getWindowDimensionsAttr(), + PermuteI64ArrayAttr(rewriter, op.getWindowDimensions(), kNchwToNhwcPermutation), /*window_strides=*/ - PermuteI64ArrayAttr(rewriter, op.getWindowStridesAttr(), + PermuteI64ArrayAttr(rewriter, op.getWindowStrides(), kNchwToNhwcPermutation), /*base_dilations=*/ - PermuteI64ArrayAttr(rewriter, op.getBaseDilationsAttr(), + PermuteI64ArrayAttr(rewriter, op.getBaseDilations(), kNchwToNhwcPermutation), /*window_dilations=*/ - PermuteI64ArrayAttr(rewriter, op.getWindowDilationsAttr(), + PermuteI64ArrayAttr(rewriter, op.getWindowDilations(), kNchwToNhwcPermutation), /*padding=*/DenseIntElementsAttr(nullptr)); @@ -199,12 +199,13 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp // `array_attr` and `permutation` must be equal. Returns a null attribute // if `array_attr` is null. 
   DenseI64ArrayAttr PermuteI64ArrayAttr(
-      PatternRewriter& rewriter, const DenseI64ArrayAttr array_attr,
+      PatternRewriter& rewriter,
+      const std::optional<ArrayRef<int64_t>> array_attr,
       const ArrayRef<int64_t> permutation) const {
-    if (array_attr == nullptr) return DenseI64ArrayAttr(nullptr);
+    if (!array_attr.has_value()) return DenseI64ArrayAttr(nullptr);
 
     return rewriter.getDenseI64ArrayAttr(
-        Permute<int64_t>(array_attr, permutation));
+        Permute<int64_t>(array_attr.value(), permutation));
   }
 
   LogicalResult MatchMaxPoolReduceWindowOp(
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc
index 051745c0d6792b..06e38c3935c417 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc
@@ -127,12 +127,13 @@ class FoldTransposedConstantOp
     if (!const_op) return failure();
 
     // Only support float tensors.
-    auto tensor_type = const_op.getType().dyn_cast_or_null<TensorType>();
+    auto tensor_type = mlir::dyn_cast_or_null<TensorType>(const_op.getType());
     if (!tensor_type || !tensor_type.getElementType().isF32()) {
       return failure();
     }
 
-    return success(const_op.getValue().isa_and_nonnull<DenseFPElementsAttr>());
+    return success(
+        mlir::isa_and_nonnull<DenseFPElementsAttr>(const_op.getValue()));
   }
 
   void rewrite(mlir::stablehlo::TransposeOp op,
@@ -140,7 +141,8 @@ class FoldTransposedConstantOp
     auto const_op =
         cast<mlir::stablehlo::ConstantOp>(op.getOperand().getDefiningOp());
 
-    const auto value_attr = const_op.getValue().cast<DenseFPElementsAttr>();
+    const auto value_attr =
+        mlir::cast<DenseFPElementsAttr>(const_op.getValue());
     const ArrayRef<int64_t> original_shape =
         value_attr.getShapedType().getShape();
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_calibration_statistics_saver.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_calibration_statistics_saver.cc
new file mode 100644
index 00000000000000..8cb0b645c312cf
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_calibration_statistics_saver.cc
@@ -0,0 +1,189 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project  // IWYU pragma: keep
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"  // IWYU pragma: keep
+#include "tsl/platform/path.h"
+
+namespace mlir::quant::stablehlo {
+namespace {
+
+std::string GetOutputFilePath(absl::string_view calibration_data_dir,
+                              absl::string_view func_name,
+                              int32_t output_file_idx) {
+  return tsl::io::JoinPath(calibration_data_dir,
+                           llvm::Twine(func_name)
+                               .concat("_")
+                               .concat(std::to_string(output_file_idx))
+                               .concat(".pb")
+                               .str());
+}
+
+// Finds `CustomAggregator` ops and collects their outputs and attributes.
+void FindCustomAggregatorOps(
+    Region& region,
+    const std::unordered_set<std::string>& aggregator_ops_to_ignore,
+    SmallVector<Value>& statistics_outputs, SmallVector<StringRef>& ids,
+    SmallVector<int32_t>& calibration_methods) {
+  for (auto op : region.getOps<TF::CustomAggregatorOp>()) {
+    if (aggregator_ops_to_ignore.count(op.getId().str())) continue;
+
+    ids.push_back(op.getId());
+    calibration_methods.push_back(op.getCalibrationMethod());
+    statistics_outputs.push_back(op.getMin());
+    statistics_outputs.push_back(op.getMax());
+    statistics_outputs.push_back(op.getHistogram());
+  }
+}
+
+// Inserts a `CalibrationStatisticsSaverOp` at the end of the region.
+LogicalResult InsertCalibrationStatisticsSaverOp(
+    Region& region, MLIRContext& ctx, absl::string_view output_file_path,
+    const std::unordered_set<std::string>& aggregator_ops_to_ignore) {
+  SmallVector<Value> statistics_outputs;
+  SmallVector<StringRef> ids;
+  SmallVector<int32_t> calibration_methods;
+  FindCustomAggregatorOps(region, aggregator_ops_to_ignore, statistics_outputs,
+                          ids, calibration_methods);
+  if (statistics_outputs.empty()) return failure();
+
+  OpBuilder builder(&ctx);
+  // Set the insertion point right before the return op.
+  builder.setInsertionPoint(&region.back().back());
+
+  StringAttr output_file_path_attr = builder.getStringAttr(output_file_path);
+  ArrayAttr ids_attr = builder.getStrArrayAttr(ids);
+  ArrayAttr calibration_methods_attr =
+      builder.getI32ArrayAttr(calibration_methods);
+  builder.create<TF::CalibrationStatisticsSaverOp>(
+      region.getLoc(), statistics_outputs, output_file_path_attr, ids_attr,
+      calibration_methods_attr);
+  return success();
+}
+
+// Returns true if the op contains a `CalibrationStatisticsSaverOp`.
+bool ContainCalibrationStatisticsSaverOp(Operation* op) {
+  // Check the region for CaseRegionOp, IfRegionOp and WhileRegionOp.
+  for (Region& region : op->getRegions()) {
+    if (!region.getOps<TF::CalibrationStatisticsSaverOp>().empty()) {
+      return true;
+    }
+  }
+
+  SymbolTable symbol_table(op->getParentOfType<ModuleOp>());
+  // Check the functions associated with CaseOp, IfOp and WhileOp.
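+  // These ops reference their branch or body functions by symbol, so a saver
+  // inside a callee must also make the calling op stateful.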
+  for (const NamedAttribute& attr : op->getAttrs()) {
+    FlatSymbolRefAttr symbol_attr =
+        dyn_cast_or_null<FlatSymbolRefAttr>(attr.getValue());
+    if (!symbol_attr) continue;
+
+    func::FuncOp target_func = dyn_cast_or_null<func::FuncOp>(
+        symbol_table.lookup(symbol_attr.getValue()));
+    if (!target_func) continue;
+
+    if (!target_func.getBody()
+             .getOps<TF::CalibrationStatisticsSaverOp>()
+             .empty()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+}  // namespace
+
+#define GEN_PASS_DECL_INSERTCALIBRATIONSTATISTICSSAVERPASS
+#define GEN_PASS_DEF_INSERTCALIBRATIONSTATISTICSSAVERPASS
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc"
+
+class InsertCalibrationStatisticsSaverPass
+    : public impl::InsertCalibrationStatisticsSaverPassBase<
+          InsertCalibrationStatisticsSaverPass> {
+ public:
+  using impl::InsertCalibrationStatisticsSaverPassBase<
+      InsertCalibrationStatisticsSaverPass>::
+      InsertCalibrationStatisticsSaverPassBase;
+
+ private:
+  void runOnOperation() override;
+};
+
+void InsertCalibrationStatisticsSaverPass::runOnOperation() {
+  ModuleOp module_op = getOperation();
+  MLIRContext& ctx = getContext();
+
+  std::unordered_set<std::string> aggregator_ops_to_ignore(
+      aggregator_ops_to_ignore_.begin(), aggregator_ops_to_ignore_.end());
+
+  // Insert CalibrationStatisticsSaverOp at the end of each region.
+  for (auto func_op : module_op.getOps<func::FuncOp>()) {
+    int32_t output_file_idx = 0;
+    StringRef func_name = func_op.getSymName();
+
+    func_op.walk([&output_file_idx, &ctx, &func_name, &aggregator_ops_to_ignore,
+                  this](Operation* op) {
+      for (Region& region : op->getRegions()) {
+        if (succeeded(InsertCalibrationStatisticsSaverOp(
+                region, ctx,
+                GetOutputFilePath(calibration_data_dir_, func_name,
+                                  output_file_idx),
+                aggregator_ops_to_ignore))) {
+          ++output_file_idx;
+        }
+      }
+    });
+  }
+
+  // Control flow ops that contain CalibrationStatisticsSaver ops must be set
+  // to stateful; otherwise the op will not be executed.
+  OpBuilder builder(&ctx);
+  module_op.walk([&builder](Operation* op) {
+    if (op->hasAttrOfType<BoolAttr>("is_stateless") &&
+        ContainCalibrationStatisticsSaverOp(op)) {
+      op->setAttr("is_stateless", builder.getBoolAttr(false));
+    }
+  });
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateInsertCalibrationStatisticsSaverPass(
+    StringRef calibration_data_dir,
+    const std::vector<std::string>& aggregator_ops_to_ignore) {
+  InsertCalibrationStatisticsSaverPassOptions options = {
+      .aggregator_ops_to_ignore_ = aggregator_ops_to_ignore,
+      .calibration_data_dir_ = calibration_data_dir.str(),
+  };
+  return std::make_unique<InsertCalibrationStatisticsSaverPass>(options);
+}
+
+}  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc
index 9fb1e9e985d15e..28396ec71ab07e 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc
@@ -13,20 +13,24 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <cstdint>
 #include <utility>
 
+#include "llvm/ADT/STLExtras.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project  // IWYU pragma: keep
 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project  // IWYU pragma: keep
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -41,6 +45,7 @@ limitations under the License.
 #include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo  // IWYU pragma: keep
 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
 #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h"
+#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h"
 #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h"
 #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h"  // IWYU pragma: keep
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -52,6 +57,10 @@ namespace mlir::quant::stablehlo {
 
 namespace {
 
+using ::stablehlo::quantization::Method;
+using ::stablehlo::quantization::QuantizedType;
+using ::stablehlo::quantization::WeightOnlyPtq;
+
 // Inserts quantization parameters of weights for weight-only quantization and
 // dynamic range quantization of `stablehlo.convolution` and
 // `stablehlo.dot_general`.
@@ -81,45 +90,58 @@ class InsertWeightParamPattern
     if (op->getNumResults() != 1) {
       return failure();
     }
-    auto type = op->getResult(0).getType().cast<TensorType>();
+    auto type = mlir::cast<TensorType>(op->getResult(0).getType());
     if (!type || !type.getElementType().isF32()) {
       return failure();
     }
-    return success(op->hasOneUse() &&
-                   IsWeightQuantizableFunction(*op->getUses().begin()));
+    return success(
+        op->hasOneUse() &&
+        IsWeightQuantizableFunction(*op->getUses().begin(), type.getRank()));
   }
 
   // Checks if the operand is the second operand of a `tf.XlaCallModule` op for
  // `stablehlo.convolution` or `stablehlo.dot_general` with the
  // fully_quantizable trait.
-  static bool IsWeightQuantizableFunction(OpOperand& operand) {
+  static bool IsWeightQuantizableFunction(OpOperand& operand, int64_t rank) {
     if (operand.getOperandNumber() != 1) {
       return false;
     }
 
     Operation* user = operand.getOwner();
-    if (isa<TF::XlaCallModuleOp>(user)) {
-      auto call_op = cast<TF::XlaCallModuleOp>(user);
-      const StringRef function_name = GetEntryFunctionName(call_op);
-      const bool is_conv_or_dot = function_name.contains("conv") ||
-                                  function_name.contains("dot_general");
-      const bool has_quant_trait = HasQuantizableTrait(call_op);
-      return is_conv_or_dot && has_quant_trait;
+    if (!IsWeightOnlyQuantizableOp(*user)) {
+      return false;
     }
-    return false;
+    Method method = GetQuantizationMethodOrDefault(user);
+    return HasValidWeightOnlyPtqMethod(method.weight_only_ptq(), rank);
   }
 
   void rewrite(Operation* op, PatternRewriter& rewriter) const override {
     Operation* quantizable_op = *op->getUsers().begin();
     DenseFPElementsAttr attr;
-    if (!matchPattern(op->getResult(0), m_Constant(&attr))) {
-      return;
+    matchPattern(op->getResult(0), m_Constant(&attr));
+
+    Method method = GetQuantizationMethodOrDefault(quantizable_op);
+    const WeightOnlyPtq& weight_only_ptq = method.weight_only_ptq();
+
+    Type weight_type;
+    if (IsPerTensor(weight_only_ptq)) {
+      weight_type = dyn_cast<quant::QuantizedType>(
+          quant::GetUniformQuantizedTypeForWeight(
+              attr, /*symmetric=*/true, /*num_bits=*/8, /*is_signed=*/true,
+              /*narrow_range=*/true, /*legacy_float_scale=*/false));
+    } else {
+      int quantization_dimension = GetQuantizationDimension(
+          weight_only_ptq, cast<TF::XlaCallModuleOp>(quantizable_op));
+      weight_type = quant::GetUniformQuantizedPerAxisTypeForWeight(
+          attr, quantization_dimension, /*symmetric=*/true, /*num_bits=*/8,
+          /*is_signed=*/true,
+          /*narrow_range=*/true, /*legacy_float_scale=*/false);
     }
-    auto quant_type =
-        quant::GetUniformQuantizedTypeForWeight(
-            attr, /*symmetric=*/false, /*num_bits=*/8, /*is_signed=*/true,
-            /*narrow_range=*/false, /*legacy_float_scale=*/false)
-            .template dyn_cast<quant::QuantizedType>();
+
+    auto quant_type = dyn_cast<quant::QuantizedType>(weight_type);
     if (!quant_type) {
+      op->emitError(
+          "Failed to get weight quantization parameters for weight-only "
+          "quantization.");
       return;
     }
@@ -134,6 +156,80 @@ class InsertWeightParamPattern
         expressed_type, q);
     quantizable_op->setOperand(1, dq.getResult());
   }
+
+ private:
+  static bool HasValidWeightOnlyPtqMethod(const WeightOnlyPtq& weight_only_ptq,
+                                          int64_t rank) {
+    const auto& input_quantized_types = weight_only_ptq.input_quantized_types();
+    if (IsPerTensor(weight_only_ptq)) {
+      return true;
+    }
+    // `input_quantized_types` should contain the spec for the quantization
+    // type of the second operand, which is the weight.
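+    // A per-channel spec for the weight operand looks roughly like this
+    // (schematic textproto):
+    //   input_quantized_types {key: 1 value {dimension_specs {dimension: 3}}}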
+    const QuantizedType& quantized_type = input_quantized_types.at(1);
+    if (const auto& specs = quantized_type.dimension_specs();
+        specs.has_dimension()) {
+      return specs.dimension() >= 0 && specs.dimension() < rank;
+    }
+    return true;
+  }
+
+  static bool IsPerTensor(const WeightOnlyPtq& weight_only_ptq) {
+    const auto& input_quantized_types = weight_only_ptq.input_quantized_types();
+    if (input_quantized_types.empty()) {
+      return true;
+    }
+    auto weight_type = input_quantized_types.find(1);
+    if (weight_type == input_quantized_types.end()) {
+      return true;
+    }
+    return weight_type->second.has_per_tensor();
+  }
+
+  static int GetQuantizationDimension(const WeightOnlyPtq& weight_only_ptq,
+                                      TF::XlaCallModuleOp op) {
+    const QuantizedType& quantized_type =
+        weight_only_ptq.input_quantized_types().at(1);
+    if (quantized_type.dimension_specs().has_dimension()) {
+      return quantized_type.dimension_specs().dimension();
+    }
+    return GetDefaultQuantizationDimension(op);
+  }
+
+  // Determines the quantization dimension of the weights for the given
+  // `tf.XlaCallModule` op. For convolution, returns the output feature
+  // dimension of the kernel. For dot_general, returns the first
+  // non-contracting, non-batching dimension of the rhs. If no such dimension
+  // exists, returns the last dimension of the rhs.
+  static int64_t GetDefaultQuantizationDimension(TF::XlaCallModuleOp op) {
+    const StringRef function_name = GetEntryFunctionName(op);
+    const auto module_op = op->getParentOfType<ModuleOp>();
+    const SymbolTable symbol_table(module_op);
+    func::FuncOp func = symbol_table.lookup<func::FuncOp>(function_name);
+
+    if (function_name.contains("conv")) {
+      return (*(func.getOps<mlir::stablehlo::ConvolutionOp>().begin()))
+          .getDimensionNumbers()
+          .getKernelOutputFeatureDimension();
+    } else if (function_name.contains("dot_general")) {
+      auto dot = *(func.getOps<mlir::stablehlo::DotGeneralOp>().begin());
+      const ::mlir::stablehlo::DotDimensionNumbersAttr dimension_numbers =
+          dot.getDotDimensionNumbers();
+      ArrayRef<int64_t> rhs_contracting_dims =
+          dimension_numbers.getRhsContractingDimensions();
+      ArrayRef<int64_t> rhs_batching_dims =
+          dimension_numbers.getRhsBatchingDimensions();
+      int64_t rank = dot.getRhs().getType().cast<ShapedType>().getRank();
+      for (int i = 0; i < rank; ++i) {
+        // Return the first non-contracting, non-batching dimension of rhs.
+        if (llvm::find(rhs_contracting_dims, i) ==
+                rhs_contracting_dims.end() &&
+            llvm::find(rhs_batching_dims, i) == rhs_batching_dims.end()) {
+          return i;
+        }
+      }
+    }
+    return op.getOperand(1).getType().cast<ShapedType>().getRank() - 1;
+  }
 };
 
 void InsertWeightParamPass::runOnOperation() {
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc
index 6577666ab90f10..d5487dd5ad8abd 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc
@@ -66,7 +66,7 @@ Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) {
 
 // Checks whether the value of a constant equals the given float, regardless
 // of the tensor dimension.
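 // For example, it returns true for both a scalar constant and a splat
 // constant of any shape holding that value.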
 bool FloatValueEquals(const Attribute& attr, const double value) {
-  const auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
+  const auto fp_attr = mlir::dyn_cast_or_null<DenseFPElementsAttr>(attr);
   if (!fp_attr) return false;
 
   if (fp_attr.isSplat()) {
@@ -208,7 +208,9 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() {
   simple_patterns::populateWithGenerated(patterns);
   fusion_patterns::populateWithGenerated(patterns);
   FrozenRewritePatternSet frozen_patterns(std::move(patterns));
-  for (auto func : module_op.getOps<func::FuncOp>()) {
+
+  // Iterate over the sorted list of functions to keep the order deterministic.
+  for (func::FuncOp func : GetSortedFunctions(module_op)) {
     if (failed(applyPatternsAndFoldGreedily(func, frozen_patterns))) {
       func.emitError()
           << "quant-stablehlo-lift-quantizable-spots-as-functions failed.";
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc
index acfe3cfd6fc6b2..9a0d8fb2a25b2b 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc
@@ -69,8 +69,8 @@ class MergeFusionWithUniformDequantizePattern
     auto func_name = call_op.getCallee();
     if (!func_name.starts_with("quantized_")) return failure();
     if (call_op->getNumResults() != 1) return failure();
-    if (!getElementTypeOrSelf(call_op->getResult(0).getType())
-             .isa<quant::UniformQuantizedType>())
+    if (!mlir::isa<quant::UniformQuantizedType>(
+            getElementTypeOrSelf(call_op->getResult(0).getType())))
       return failure();
 
     // Fetch the callee function.
@@ -89,8 +89,8 @@ class MergeFusionWithUniformDequantizePattern
     // Create a new func.call op with f32 output.
     auto new_call_op = call_op.clone();
     new_call_op->getResult(0).setType(
-        call_op.getResult(0).getType().cast<ShapedType>().clone(
-            rewriter.getF32Type()));
+        mlir::cast<ShapedType>(call_op.getResult(0).getType())
+            .clone(rewriter.getF32Type()));
     rewriter.setInsertionPoint(call_op);
     rewriter.insert(new_call_op);
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc
index 521f701598fb0a..ed2da6ed103273 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc
@@ -73,7 +73,7 @@ class RewriteNchwConvolutionToNhwc
     // Transpose the input tensor: [b, f, 0, 1] => [b, 0, 1, f]
     Value input = op->getOperand(0);
     const TensorType new_input_tensor_type = GetTransposedTensorType(
-        input.getType().cast<TensorType>(), kNchwToNhwcPermutation);
+        mlir::cast<TensorType>(input.getType()), kNchwToNhwcPermutation);
 
     auto input_transpose_op = rewriter.create<mlir::stablehlo::TransposeOp>(
         op.getLoc(), /*resultType0=*/new_input_tensor_type, /*operand=*/input,
@@ -82,7 +82,7 @@ class RewriteNchwConvolutionToNhwc
     // Transpose the filter tensor: [o, i, 0, 1] => [0, 1, i, o]
     Value filter = op->getOperand(1);
     const TensorType new_filter_tensor_type = GetTransposedTensorType(
-        filter.getType().cast<TensorType>(), kOihwToHwioPermutation);
+        mlir::cast<TensorType>(filter.getType()), kOihwToHwioPermutation);
 
     auto filter_transpose_op = rewriter.create<mlir::stablehlo::TransposeOp>(
         op.getLoc(), /*resultType0=*/new_filter_tensor_type, /*operand=*/filter,
@@ -98,7 +98,8 @@ class RewriteNchwConvolutionToNhwc
         /*outputSpatialDimensions=*/SmallVector<int64_t>{1, 2});
 
     // Determine the shape of the output tensor: [b, f, 0, 1] => [b, 0, 1, f]
-    auto output_tensor_type = op->getResult(0).getType().cast<TensorType>();
+    auto output_tensor_type =
+        mlir::cast<TensorType>(op->getResult(0).getType());
     const TensorType new_conv_output_tensor_type =
         GetTransposedTensorType(output_tensor_type, kNchwToNhwcPermutation);
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h
index 2937eec8d9a2f0..d13c589c2ba890 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <memory>
 #include <string>
+#include <vector>
 
 #include "absl/status/statusor.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
@@ -44,6 +45,12 @@ std::unique_ptr<OperationPass<ModuleOp>>
 CreateLiftQuantizableSpotsAsFunctionsPass(
     const ::stablehlo::quantization::QuantizationSpecs& quantization_specs);
 
+// Creates a pass that inserts CalibrationStatisticsSaverOp.
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateInsertCalibrationStatisticsSaverPass(
+    StringRef calibration_data_dir,
+    const std::vector<std::string>& aggregator_ops_to_ignore);
+
 // Adds generated pass default constructors or options definitions.
 #define GEN_PASS_DECL
 // Adds generated pass registration functions.
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
index fdb7fa7941f025..7661e8d562fbe9 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
@@ -63,10 +63,6 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function
     Option<"mlir_dump_file_name_", "mlir-dump-file-name",
            "std::optional<std::string>", /*default=*/"std::nullopt",
            "MLIR dump file name.">,
-    Option<"enable_weight_only_",
-           "enable-weight-only",
-           "bool", /*default=*/"false",
-           "Whether to produce weight-only quantized op for convolution and dot_general op.">,
     Option<"merge_fusion_with_dequantize_",
            "merge-fusion-with-dequantize",
            "bool", /*default=*/"false",
@@ -106,10 +102,6 @@ def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> {
            "enable-per-channel-quantized-weight",
            "bool", /*default=*/"true",
            "Whether to enable per-channel quantized weights.">,
-    Option<"enable_weight_only_",
-           "enable-weight-only",
-           "bool", /*default=*/"false",
-           "Whether to produce weight-only quantized op for convolution and dot_general op.">,
   ];
   let dependentDialects = [
     "mlir::stablehlo::StablehloDialect",
@@ -228,3 +220,20 @@ def RemoveShardingCustomCallPass : Pass<"stablehlo-remove-sharding-custom-call",
   }];
   let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
 }
+
+def InsertCalibrationStatisticsSaverPass : Pass<"stablehlo-insert-calibration-statistics-saver", "ModuleOp"> {
+  let summary = "Inserts `CalibrationStatisticsSaver` op to collect and save calibration statistics.";
+  let description = [{
+    Finds all `CustomAggregator` ops in each function and adds a single
+    `CalibrationStatisticsSaver` op at the end of the function to collect their
+    statistics.
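+    Each saver writes its statistics to a separate file under
+    `calibration-data-dir`, named `<function_name>_<index>.pb`.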
+  }];
+  let options = [
+    ListOption<"aggregator_ops_to_ignore_", "aggregator-ops-to-ignore", "std::string",
+               "Ops to ignore when inserting CalibrationStatisticsSaver.">,
+    Option<"calibration_data_dir_", "calibration-data-dir",
+           "std::string", /*default=*/"",
+           "The directory to save calibration data.">,
+  ];
+  let dependentDialects = ["TF::TensorFlowDialect"];
+}
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc
index a6d041a5b8cb9e..787fca3594f14a 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc
@@ -75,6 +75,7 @@
 using ::mlir::stablehlo::GetDimensionSizeOp;
 using ::mlir::stablehlo::ReshapeOp;
 using ::mlir::stablehlo::UniformQuantizeOp;
 using ::stablehlo::quantization::Method;
+using ::stablehlo::quantization::QuantizedDimension;
 using ::stablehlo::quantization::QuantizedType;
 using ::stablehlo::quantization::StaticRangePtq;
 
@@ -237,7 +238,7 @@ void CreateAndReturnQuantizedBiasPattern(
   if (succeeded(bcast_op)) {
     Value bcast_op_result = (*bcast_op)->getResult(0);
     auto bcast_op_result_type =
-        bcast_op_result.getType().cast<RankedTensorType>();
+        mlir::cast<RankedTensorType>(bcast_op_result.getType());
     const ArrayRef<int64_t> bcast_shape = bcast_op_result_type.getShape();
     const TensorType new_bcast_op_result_type = bcast_op_result_type.cloneWith(
         bcast_shape, accumulation_quantized_element_type);
@@ -245,7 +246,7 @@ void CreateAndReturnQuantizedBiasPattern(
   }
 
   const auto add_op_result_type =
-      add_op_result.getType().cast<RankedTensorType>();
+      mlir::cast<RankedTensorType>(add_op_result.getType());
   const ArrayRef<int64_t> add_op_shape = add_op_result_type.getShape();
   // For quantized bias add case, lhs, rhs, and result have the same types.
   const TensorType new_add_op_result_type = add_op_result_type.cloneWith(
@@ -269,7 +270,8 @@ class EntryFuncBodyQuantizationPattern {
   // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At
   // this point `entry_func_op`'s signature has not been reset with quantized
   // types.
-  virtual LogicalResult match(func::FuncOp entry_func_op) const = 0;
+  virtual LogicalResult match(func::FuncOp entry_func_op,
+                              const Method& quantization_method) const = 0;
 
   // Rewrites the `entry_func_op`'s body.
   virtual void rewrite(func::FuncOp entry_func_op,
@@ -318,7 +320,7 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter,
 
   Value gemm_style_op_result = gemm_style_op->getResult(0);
   const auto gemm_style_op_result_type =
-      gemm_style_op_result.getType().cast<ShapedType>();
+      mlir::cast<ShapedType>(gemm_style_op_result.getType());
   const ArrayRef<int64_t> gemm_style_shape =
       gemm_style_op_result_type.getShape();
 
@@ -326,11 +328,12 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter,
   TensorType new_gemm_style_op_result_type;
 
   const double input_scale =
-      getElementTypeOrSelf(input_type).cast<UniformQuantizedType>().getScale();
+      mlir::cast<UniformQuantizedType>(getElementTypeOrSelf(input_type))
+          .getScale();
 
   if (enable_per_channel_quantized_weight) {
-    ArrayRef<double> filter_scales = getElementTypeOrSelf(filter_type)
-                                         .cast<UniformQuantizedPerAxisType>()
+    ArrayRef<double> filter_scales = mlir::cast<UniformQuantizedPerAxisType>(
+                                         getElementTypeOrSelf(filter_type))
                                          .getScales();
     std::vector<double> result_scales;
    result_scales.reserve(filter_scales.size());
@@ -340,8 +343,8 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter,
     }
 
     const ArrayRef<int64_t> zero_points =
-        getElementTypeOrSelf(filter_type)
-            .cast<UniformQuantizedPerAxisType>()
+        mlir::cast<UniformQuantizedPerAxisType>(
+            getElementTypeOrSelf(filter_type))
             .getZeroPoints();
 
     // `stablehlo.convolution` assumes the following format:
@@ -351,7 +354,7 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter,
     // `stablehlo.dot_general` legalizable to `tfl.fully_connected` has a
     // filter rank of 2 with the last dimension as the channel dimension.
     const int64_t quantization_dimension =
-        filter_type.cast<ShapedType>().getShape().size() - 1;
+        mlir::cast<ShapedType>(filter_type).getShape().size() - 1;
     accumulation_quantized_element_type =
         CreateI32F32UniformQuantizedPerAxisType(
             gemm_style_op->getLoc(), *rewriter.getContext(), result_scales,
@@ -360,9 +363,9 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter,
     new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith(
         gemm_style_shape, accumulation_quantized_element_type);
   } else {
-    const double filter_scale = getElementTypeOrSelf(filter_type)
-                                    .cast<UniformQuantizedType>()
-                                    .getScale();
+    const double filter_scale =
+        mlir::cast<UniformQuantizedType>(getElementTypeOrSelf(filter_type))
+            .getScale();
     const double result_scale = input_scale * filter_scale;
 
     accumulation_quantized_element_type = CreateI32F32UniformQuantizedType(
@@ -408,19 +411,20 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter,
 class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern {
  public:
   explicit QuantizeDotGeneralOpPattern(
-      const bool enable_per_channel_quantized_weight,
-      const bool enable_weight_only)
+      const bool enable_per_channel_quantized_weight)
       : enable_per_channel_quantized_weight_(
-            enable_per_channel_quantized_weight),
-        enable_weight_only_(enable_weight_only) {}
+            enable_per_channel_quantized_weight) {}
 
-  LogicalResult match(func::FuncOp entry_func_op) const override {
+  LogicalResult match(func::FuncOp entry_func_op,
+                      const Method& quantization_method) const override {
+    if (!quantization_method.has_static_range_ptq()) {
+      return failure();
+    }
     return MatchGemmStyleOp<DotGeneralOp>(entry_func_op);
   }
 
   void rewrite(func::FuncOp entry_func_op, const Method& quantization_method,
                PatternRewriter& rewriter) const override {
-    if (enable_weight_only_) return;
     DotGeneralOp dot_general_op = *entry_func_op.getOps<DotGeneralOp>().begin();
     const bool should_quantize_per_channel =
         enable_per_channel_quantized_weight_ &&
@@ -433,28 +437,26 @@ class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern {
   [[deprecated(
"Do not rely on this field for per-channel quantization. Use `Method` " "instead.")]] const bool enable_per_channel_quantized_weight_; - // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform - // weight-only quantization. - const bool enable_weight_only_; }; // Quantizes the entry function's body containing a `ConvolutionOp`. class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeConvolutionOpPattern( - const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) + const bool enable_per_channel_quantized_weight) : enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight), - enable_weight_only_(enable_weight_only) {} + enable_per_channel_quantized_weight) {} - LogicalResult match(func::FuncOp entry_func_op) const override { + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_static_range_ptq()) { + return failure(); + } return MatchGemmStyleOp(entry_func_op); } void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override { - if (enable_weight_only_) return; RewriteGemmStyleOp( entry_func_op, rewriter, enable_per_channel_quantized_weight_ && @@ -463,7 +465,8 @@ class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { // Returns true if the quantization method indicates per-channel quantization // for convolution weights. This method specifically matches a quantization - // dimension of 3 for the input index 1. + // dimension of 3 for the input index 1 or unspecified quantization dimension + // for the input index 1. bool IsWeightPerChannelQuantized(const Method& quantization_method) const { if (quantization_method.has_static_range_ptq()) { const StaticRangePtq& static_range_ptq_spec = @@ -472,7 +475,13 @@ class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { if (static_range_ptq_spec.input_quantized_types().contains(1)) { const QuantizedType& weight_quantized_type = static_range_ptq_spec.input_quantized_types().at(1); - return weight_quantized_type.dimension_specs().dimension() == 3; + if (weight_quantized_type.has_per_tensor()) { + return false; + } + const QuantizedDimension& dimension_specs = + weight_quantized_type.dimension_specs(); + return !dimension_specs.has_dimension() || + dimension_specs.dimension() == 3; } } return false; @@ -482,25 +491,60 @@ class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { [[deprecated( "Do not rely on this field for per-channel quantization. Use `Method` " "instead.")]] const bool enable_per_channel_quantized_weight_; - // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform - // weight-only quantization. - const bool enable_weight_only_; +}; + +// Quantizes the entry function's body for weight-only quantized op. 
+template <typename GemmStyleOp>
+class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern {
+ public:
+  explicit QuantizeWeightOnlyOpPattern(
+      const bool enable_per_channel_quantized_weight)
+      : enable_per_channel_quantized_weight_(
+            enable_per_channel_quantized_weight) {}
+
+  LogicalResult match(func::FuncOp entry_func_op,
+                      const Method& quantization_method) const override {
+    if (!quantization_method.has_weight_only_ptq()) {
+      return failure();
+    }
+    return MatchGemmStyleOp<GemmStyleOp>(entry_func_op);
+  }
+
+  void rewrite(func::FuncOp entry_func_op, const Method& quantization_method,
+               PatternRewriter& rewriter) const override {}
+
+ private:
+  [[deprecated(
+      "Do not rely on this field for per-channel quantization. Use `Method` "
+      "instead.")]] const bool enable_per_channel_quantized_weight_;
+};
 
 template <typename SingularOpT>
 class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern {
  public:
   explicit QuantizeSingularOpPattern(
-      const bool enable_per_channel_quantized_weight,
-      const bool enable_weight_only) {}
+      const bool enable_per_channel_quantized_weight) {}
 
-  LogicalResult match(func::FuncOp entry_func_op) const override {
+  LogicalResult match(func::FuncOp entry_func_op,
+                      const Method& quantization_method) const override {
+    if (!quantization_method.has_static_range_ptq()) {
+      return failure();
+    }
     const auto op_iterator_range = entry_func_op.getOps<SingularOpT>();
     if (op_iterator_range.empty()) {
       LLVM_DEBUG(llvm::dbgs() << "Function does not have "
                               << SingularOpT::getOperationName() << " op.\n");
       return failure();
     }
+
+    // The entry function body should have a single block with two ops (the op
+    // to be quantized and the return op).
+    Region& body = entry_func_op.getBody();
+    if (body.getBlocks().size() != 1 ||
+        body.begin()->getOperations().size() != 2) {
+      return failure();
+    }
+
     if (!isa<TensorType>(
             (*op_iterator_range.begin()).getResult().getType())) {
       LLVM_DEBUG(llvm::dbgs() << SingularOpT::getOperationName()
@@ -526,13 +570,13 @@ class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern {
 
     // Get the quantized tensor manipulation op's output type and update.
     const auto singular_op_result_type =
-        singular_op_result.getType().cast<ShapedType>();
+        mlir::cast<ShapedType>(singular_op_result.getType());
     const ArrayRef<int64_t> singular_op_shape =
         singular_op_result_type.getShape();
     const TensorType new_singular_op_result_type =
         singular_op_result_type.cloneWith(
-            singular_op_shape,
-            getElementTypeOrSelf(operand_type).cast<UniformQuantizedType>());
+            singular_op_shape, mlir::cast<UniformQuantizedType>(
+                                   getElementTypeOrSelf(operand_type)));
     singular_op_result.setType(new_singular_op_result_type);
 
     // Create requantization op and return.
@@ -599,9 +643,9 @@ void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp(
     const EntryFuncBodyQuantizationPattern& body_rewrite_pattern,
     const Method& quantization_method) {
   const ModuleOp module_op = xla_call_module_op->getParentOfType<ModuleOp>();
-  const SymbolTable symbol_table(module_op);
 
-  func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table);
+  func::FuncOp entry_func_op =
+      GetEntryFuncOp(xla_call_module_op, SymbolTable(module_op));
   QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op,
                       body_rewrite_pattern, quantization_method);
 
@@ -627,16 +671,13 @@ template <typename FuncBodyRewritePatternT,
 class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
  public:
   explicit XlaCallModuleOpToCallOp(
-      MLIRContext& ctx, const bool enable_per_channel_quantized_weight,
-      const bool enable_weight_only)
+      MLIRContext& ctx, const bool enable_per_channel_quantized_weight)
       : OpRewritePattern<TF::XlaCallModuleOp>(&ctx),
         enable_per_channel_quantized_weight_(
-            enable_per_channel_quantized_weight),
-        enable_weight_only_(enable_weight_only) {}
+            enable_per_channel_quantized_weight) {}
 
   LogicalResult match(TF::XlaCallModuleOp op) const override {
     ModuleOp module_op = op->getParentOfType<ModuleOp>();
-    SymbolTable symbol_table(module_op);
 
     // Ignore ops without quantization method.
     // Consider adding checks for individual methods.
@@ -646,19 +687,18 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
     if (!IsQuantizedXlaCallModuleOp(op)) return failure();
 
     // For weight-only quantization, op should be hybrid quantized.
-    if (enable_weight_only_ && !IsHybridQuantizedOp(op)) {
+    if (HasWeightOnlyPtqMethod(op) && !IsHybridQuantizedOp(op)) {
       return failure();
     }
 
-    func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table);
+    func::FuncOp entry_func_op = GetEntryFuncOp(op, SymbolTable(module_op));
     if (!entry_func_op) {
       op->emitError("Failed to find a valid entry function.");
       return failure();
     }
-
-    return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_,
-                                   enable_weight_only_)
-        .match(entry_func_op);
+    Method quantization_method = GetQuantizationMethodOrDefault(op);
+    return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_)
+        .match(entry_func_op, quantization_method);
   }
 
   void rewrite(TF::XlaCallModuleOp xla_call_module_op,
@@ -671,8 +711,7 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
 
     ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp(
         *rewriter.getContext(), rewriter, xla_call_module_op,
-        FuncBodyRewritePatternT(enable_per_channel_quantized_weight_,
-                                enable_weight_only_),
+        FuncBodyRewritePatternT(enable_per_channel_quantized_weight_),
        quantization_method);
   }
 
@@ -680,9 +719,6 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
   [[deprecated(
       "Do not rely on this field for per-channel quantization. Use `Method` "
       "instead.")]] const bool enable_per_channel_quantized_weight_;
-  // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform
-  // weight-only quantization.
-  const bool enable_weight_only_;
 };
 
 // Quantizes op with regions such as stablehlo.reduce_window op.
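 // Such ops are quantized in place: operands fed by dequantize casts and
 // results consumed by quantize casts are rewired to use the quantized types
 // directly, and the nested region types are updated to match.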
@@ -733,13 +769,13 @@ class QuantizeOpWithRegionPattern
       inputs.reserve(op_with_region->getNumOperands());
       for (Value operand : op_with_region->getOperands()) {
         const Type operand_type = operand.getType();
-        if (operand_type.isa<NoneType>()) {
+        if (mlir::isa<NoneType>(operand_type)) {
           inputs.push_back(operand);
           continue;
         }
 
         const Type element_type =
-            operand.getType().cast<TensorType>().getElementType();
+            mlir::cast<TensorType>(operand.getType()).getElementType();
         if (auto dq_op = dyn_cast_or_null<quantfork::DequantizeCastOp>(
                 operand.getDefiningOp())) {
           inputs.push_back(dq_op.getOperand());
@@ -759,13 +795,13 @@ class QuantizeOpWithRegionPattern
       output_types.reserve(op_with_region->getNumResults());
       for (const Value result : op_with_region->getResults()) {
         const Type result_type = result.getType();
-        if (result_type.isa<NoneType>()) {
+        if (mlir::isa<NoneType>(result_type)) {
           outputs_replaced.push_back(result);
           output_types.push_back(result_type);
           continue;
         }
         const Type result_element_type =
-            result.getType().cast<TensorType>().getElementType();
+            mlir::cast<TensorType>(result.getType()).getElementType();
         // If the user is the QuantizeOp, it must be the only user.
         if (result.hasOneUse() &&
             isa<quantfork::QuantizeCastOp>(*result.user_begin())) {
@@ -799,7 +835,7 @@ class QuantizeOpWithRegionPattern
 
       const Type operand_type = quantized_op->getOperandTypes()[0];
       const Type element_type =
-          operand_type.cast<TensorType>().getElementType();
+          mlir::cast<TensorType>(operand_type).getElementType();
       for (Region& region : quantized_op->getRegions()) {
         ReplaceTypesInNestedRegion(region, element_type);
       }
@@ -856,7 +892,7 @@ class QuantizeOpWithRegionPattern
   // Replaces element type of the given tensor type while preserving shape of
   // the given type. If the given type is not tensor type, just return itself.
   Type ReplaceElementType(const Type type, const Type element_type) const {
-    if (TensorType tensor_type = type.dyn_cast<TensorType>()) {
+    if (TensorType tensor_type = mlir::dyn_cast<TensorType>(type)) {
       return tensor_type.clone(element_type);
     }
     return type;
@@ -874,23 +910,23 @@ bool IsQuantizedCompositeFunction(func::CallOp call_op) {
   bool has_quantized_types = false;
   for (Value operand : call_op.getOperands()) {
-    if (const TensorType type = operand.getType().dyn_cast<TensorType>()) {
-      if (type.getElementType().isa<FloatType>()) {
+    if (const TensorType type = mlir::dyn_cast<TensorType>(operand.getType())) {
+      if (mlir::isa<FloatType>(type.getElementType())) {
         return false;
       }
-      if (type.getElementType()
-              .isa<UniformQuantizedType, UniformQuantizedPerAxisType>()) {
+      if (mlir::isa<UniformQuantizedType, UniformQuantizedPerAxisType>(
+              type.getElementType())) {
         has_quantized_types = true;
       }
     }
   }
   for (const Value result : call_op.getResults()) {
-    if (const auto type = result.getType().dyn_cast<TensorType>()) {
-      if (type.getElementType().isa<FloatType>()) {
+    if (const auto type = mlir::dyn_cast<TensorType>(result.getType())) {
+      if (mlir::isa<FloatType>(type.getElementType())) {
        return false;
      }
-      if (type.getElementType()
-              .isa<UniformQuantizedType, UniformQuantizedPerAxisType>()) {
+      if (mlir::isa<UniformQuantizedType, UniformQuantizedPerAxisType>(
+              type.getElementType())) {
        has_quantized_types = true;
      }
    }
  }
@@ -919,7 +955,7 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) {
             ->has_same_scale_requirement) {
       for (const OpResult result : preceding_op->getResults()) {
         const Type element_type = getElementTypeOrSelf(result.getType());
-        if (element_type.isa<UniformQuantizedType>()) {
+        if (mlir::isa<UniformQuantizedType>(element_type)) {
           return true;
         }
       }
@@ -947,7 +983,7 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) {
             ->has_same_scale_requirement) {
       for (Value operand : following_op->getOperands()) {
         const Type element_type = getElementTypeOrSelf(operand.getType());
-        if (element_type.isa<UniformQuantizedType>()) {
+        if (mlir::isa<UniformQuantizedType>(element_type)) {
           return true;
         }
       }
@@ -958,20 +994,6 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) {
 
   return false;
 }
 
-template <typename GemmStyleOp>
-class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern {
- public:
-  explicit QuantizeWeightOnlyOpPattern(
-      const bool enable_per_channel_quantized_weight) {}
-
-  LogicalResult match(func::FuncOp entry_func_op) const override {
-    return MatchGemmStyleOp<GemmStyleOp>(entry_func_op);
-  }
-
-  void rewrite(func::FuncOp entry_func_op, const Method& quantization_method,
-               PatternRewriter& rewriter) const override {}
-};
-
 // Compute heavy patterns should be quantized for both server and ODML targets.
 // Most patterns here are useful when quantized since they are compute heavy
 // or memory bound.
@@ -979,13 +1001,18 @@ void PopulateCommonQuantizationPatterns(
     MLIRContext& ctx, RewritePatternSet& patterns,
     const bool enable_per_channel_quantized_weight) {
   patterns.add<XlaCallModuleOpToCallOp<QuantizeConvolutionOpPattern>>(
-      ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false);
+      ctx, enable_per_channel_quantized_weight);
   patterns.add<XlaCallModuleOpToCallOp<QuantizeDotGeneralOpPattern>>(
-      ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false);
+      ctx, enable_per_channel_quantized_weight);
+  patterns
+      .add<XlaCallModuleOpToCallOp<QuantizeWeightOnlyOpPattern<ConvolutionOp>>>(
+          ctx, enable_per_channel_quantized_weight);
+  patterns
+      .add<XlaCallModuleOpToCallOp<QuantizeWeightOnlyOpPattern<DotGeneralOp>>>(
+          ctx, enable_per_channel_quantized_weight);
   // TODO: b/307620772 - Per-channel quantization for gather.
   patterns.add<XlaCallModuleOpToCallOp<QuantizeSingularOpPattern<GatherOp>>>(
-      ctx, /*enable_per_channel_quantized_weight=*/false,
-      /*enable_weight_only=*/false);
+      ctx, /*enable_per_channel_quantized_weight=*/false);
   // Populate pattern for quantization of ops with regions such as
   // `stablehlo.reduce_window` op.
   patterns.add<QuantizeOpWithRegionPattern>(ctx);
@@ -994,16 +1021,7 @@ void PopulateCommonQuantizationPatterns(
 
 void PopulateAllQuantizablePatterns(MLIRContext& ctx,
                                     RewritePatternSet& patterns) {
   patterns.add<XlaCallModuleOpToCallOp<QuantizeSingularOpPattern<GatherOp>>>(
-      ctx, /*enable_per_channel_quantized_weight=*/false,
-      /*enable_weight_only=*/false);
-}
-
-void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx,
-                                        RewritePatternSet& patterns) {
-  patterns.add<XlaCallModuleOpToCallOp<QuantizeConvolutionOpPattern>,
-               XlaCallModuleOpToCallOp<QuantizeDotGeneralOpPattern>>(
-      ctx, /*enable_per_channel_quantized_weight*/ false,
-      /*enable_weight_only=*/true);
+      ctx, /*enable_per_channel_quantized_weight=*/false);
 }
 
 }  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h
index 67eb267c1d9037..c07314d6cff6cf 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h
@@ -40,6 +40,7 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
+#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h"
 #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h"
 #include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h"
 #include "tensorflow/core/framework/types.pb.h"
@@ -59,18 +60,8 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op);
 // quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs.
 // Each matched pattern is rewritten by its quantized alternative.
 //
-// The concrete pattern, which extends from this base pattern, can specify
-// whether it allows weight-only quantization. If it is allowed, an
-// operand/result that is not adjacent to a dequantize/quantize op remains as
-// float. An
For -// operand/result that is adjacent to dequantize/quantize, it is quantized. -// Weight-only quantization can be used to generate both weight-only -// quantization and dynamic range quantization. The condition for allowing -// weight-only quantization or not for an op can be specified in the below -// function: -// -// static bool AllowWeightOnlyQuantization(Operation& op) -// -// This is a templatized `OpRewritePattern`. +// The quantization method is determined by the `_quantization_method` attribute +// attached to each quantizable unit. // // Template constraints are imposed as follows: // @@ -159,18 +150,22 @@ class StableHloQuantizationPattern : public OpRewritePattern { return failure(); } + const bool weight_only_quantizable = + IsWeightOnlyQuantizableOp(*candidate_op); + // Collect all the quantized inputs and "clone" the matched op by these // inputs. SmallVector inputs; inputs.reserve(candidate_op->getNumOperands()); for (auto operand : candidate_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (mlir::isa(operand_type)) { inputs.push_back(operand); continue; } - auto ele_type = operand.getType().cast().getElementType(); + auto ele_type = + mlir::cast(operand.getType()).getElementType(); if (auto dq_op = dyn_cast_or_null(operand.getDefiningOp())) { inputs.push_back(dq_op.getOperand()); @@ -178,8 +173,7 @@ class StableHloQuantizationPattern : public OpRewritePattern { // If the operand is an integer tensor, then it doesn't require the // DequantizeOp in the pattern. inputs.push_back(operand); - } else if (static_cast(this) - ->AllowWeightOnlyQuantization(*candidate_op)) { + } else if (weight_only_quantizable) { inputs.push_back(operand); } else { return failure(); @@ -197,13 +191,13 @@ class StableHloQuantizationPattern : public OpRewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none type // results. - if (result_type.isa()) { + if (mlir::isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; } Type result_ele_type = - result.getType().cast().getElementType(); + mlir::cast(result.getType()).getElementType(); // If the user is the QuantizeOp, it must be the only user. if (result.hasOneUse() && isa(*result.user_begin())) { auto user = cast(*result.user_begin()); @@ -215,8 +209,7 @@ class StableHloQuantizationPattern : public OpRewritePattern { // D op in the pattern. outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); - } else if (static_cast(this) - ->AllowWeightOnlyQuantization(*candidate_op)) { + } else if (weight_only_quantizable) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); } else { @@ -260,10 +253,6 @@ void PopulateCommonQuantizationPatterns( MLIRContext& ctx, RewritePatternSet& patterns, const bool enable_per_channel_quantized_weight); void PopulateAllQuantizablePatterns(MLIRContext& ctx, RewritePatternSet& patterns); -// Populates pattern weight-only quantization. 
-void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx, - RewritePatternSet& patterns); - } // namespace mlir::quant::stablehlo #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 0000057402886f..86dbae8e4181f9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -77,35 +77,14 @@ struct StableHloQuantizationReverse quantfork::QuantizeCastOp>(ctx) {} }; -bool IsHybridQuantizableOp(Operation& op) { - auto call_op = cast(op); - if (call_op == nullptr) return false; - StringRef entry_function_name = GetEntryFunctionName(call_op); - return entry_function_name.contains("conv") || - entry_function_name.contains("dot_general"); -} - -// Quantization rewrite pattern using DQ as the root op. -struct StableHloQuantizationWeightOnly - : public StableHloQuantizationBase { - explicit StableHloQuantizationWeightOnly(MLIRContext* ctx) - : StableHloQuantizationBase(ctx) {} - - static bool AllowWeightOnlyQuantization(Operation& op) { - return IsHybridQuantizableOp(op); - } -}; - class QuantizePass : public impl::QuantizePassBase { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizePass) using impl::QuantizePassBase::QuantizePassBase; - explicit QuantizePass(const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) { + explicit QuantizePass(const bool enable_per_channel_quantized_weight) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; - enable_weight_only_ = enable_weight_only; } private: @@ -118,10 +97,6 @@ void QuantizePass::runOnOperation() { RewritePatternSet patterns(&ctx); patterns.add(&ctx); - if (enable_weight_only_) { - patterns.add(&ctx); - PopulateQuantizeWeightOnlyPatterns(ctx, patterns); - } PopulateCommonQuantizationPatterns(ctx, patterns, enable_per_channel_quantized_weight_); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index 1efc5d40c7ce20..a713f5501b271d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "absl/log/log.h" #include "absl/status/status.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // IWYU pragma: keep @@ -26,8 +24,6 @@ limitations under the License. 
#include "mlir/Support/TypeID.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" @@ -55,10 +51,8 @@ class QuantizeCompositeFunctionsPass QuantizeCompositeFunctionsPass>::QuantizeCompositeFunctionsPassBase; explicit QuantizeCompositeFunctionsPass( - const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) { + const bool enable_per_channel_quantized_weight) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; - enable_weight_only_ = enable_weight_only; } private: @@ -80,9 +74,10 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { // Change this to user-given bit width once we have custom configuration. options.bit_width_ = 8; - if (enable_weight_only_) { - pm.addNestedPass(createInsertWeightParamPass()); - } + // Insert quantization parameters for weights for ops with `weight_only_ptq` + // attribute. + pm.addNestedPass(createInsertWeightParamPass()); + // PrepareQuantizePass uses SymbolTable to fetch relevant GEMM ops for // determining quantization attributes. This requires module-level context. pm.addPass(createPrepareQuantizePass(options)); @@ -90,7 +85,7 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { QuantizePassOptions quantize_options; quantize_options.enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight_; - quantize_options.enable_weight_only_ = enable_weight_only_; + // QuantizePass modifies FuncOps referenced outside of its given scope // and therefore requires a module-level context. pm.addPass(createQuantizePass(quantize_options)); @@ -113,10 +108,6 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { !pm_run_status.ok()) { signalPassFailure(); } - - // Emit human-readable quantization report. - const QuantizationReport report(module_op); - report.Print(); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc index 95f150d683c57b..e0469cc8d14032 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc @@ -111,7 +111,7 @@ class QuantizeWeight : public OpRewritePattern { QuantizationUnits GetQuantizableOps(ConstantOp op) const { // Non-float tensors do not need quantization. 
QuantizationUnits quantizable_ops; - const ShapedType type = op.getType().dyn_cast(); + const ShapedType type = mlir::dyn_cast(op.getType()); if (!type || !type.getElementType().isF32()) return quantizable_ops; const Value value = op.getResult(); @@ -150,7 +150,7 @@ class QuantizeWeight : public OpRewritePattern { } TensorType old_result_type = - op.getResult().getType().dyn_cast(); + mlir::dyn_cast(op.getResult().getType()); const FloatType quantized_type = FloatType::getF16(op.getContext()); const ShapedType new_result_type = old_result_type.clone(quantized_type); @@ -184,7 +184,7 @@ class QuantizeWeight : public OpRewritePattern { // Get types. const Type old_result_type = op.getResult().getType(); const ShapedType new_result_type = - convert_op.getType().dyn_cast(); + mlir::dyn_cast(convert_op.getType()); // Proceeds only if the converting is to float16. if (!new_result_type.getElementType().isF16()) continue; @@ -192,7 +192,7 @@ class QuantizeWeight : public OpRewritePattern { // Convert values. std::vector new_values; const DenseFPElementsAttr value_attr = - op.getValue().cast(); + mlir::cast(op.getValue()); new_values.reserve(value_attr.getNumElements()); for (const float value : value_attr.getValues()) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc index 6ed82c125b0be9..e1b4adb013684c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -163,7 +163,7 @@ void CreateXlaCallModuleOp(ValueRange inputs, ValueRange outputs, SmallVector shape_attrs; for (const Type result_type : result_types) { shape_attrs.push_back( - tf_type::ShapeAttr::get(ctx, result_type.cast())); + tf_type::ShapeAttr::get(ctx, mlir::cast(result_type))); } const auto empty_array_attr = ArrayAttr::get(ctx, {}); // TODO: b/310291615 - find a better way for platform support. 
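A note on the recurring cast rewrites in the C++ hunks above: they migrate MLIR's deprecated member-style casts (`x.isa<T>()`, `x.cast<T>()`, `x.dyn_cast<T>()`) to the free functions re-exported through `mlir/Support/LLVM.h`. A minimal sketch of the new spelling follows; the helper name is illustrative and not part of this patch:

#include "mlir/IR/BuiltinTypes.h"  // mlir::TensorType
#include "mlir/IR/Types.h"         // mlir::Type
#include "mlir/Support/LLVM.h"     // re-exports llvm::isa/cast/dyn_cast into mlir::

// Illustrative helper (not in this patch): unwrap a tensor's element type,
// or return the type unchanged when it is not shaped.
static mlir::Type ElementTypeOrSelfSketch(mlir::Type type) {
  // New spelling: a free function taking the casted value as an argument.
  if (auto tensor_type = mlir::dyn_cast<mlir::TensorType>(type))
    return tensor_type.getElementType();
  // Deprecated spelling being removed upstream, for comparison:
  //   if (auto tensor_type = type.dyn_cast<mlir::TensorType>()) ...
  return type;
}

The free functions also accept several candidate types in one call (`mlir::isa<A, B>(v)`), which the quantized-element-type checks in `IsQuantizedCompositeFunction` appear to rely on.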
@@ -502,7 +502,7 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass:: SymbolTable symbol_table(module_op); for (auto call_op : main_func.getOps()) { func_ops.push_back(dyn_cast_or_null(symbol_table.lookup( - call_op.getFAttr().cast().getValue()))); + mlir::cast(call_op.getFAttr()).getValue()))); } for (auto call_op : main_func.getOps()) { func_ops.push_back( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index df5252b986adf5..0999d37da524c2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -30,6 +30,7 @@ package( pytype_strict_library( name = "quantization", srcs = ["quantization.py"], + visibility = ["//visibility:public"], deps = [ ":pywrap_quantization", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_py", @@ -45,6 +46,10 @@ pytype_strict_library( # testonly = 1, # srcs = ["integration_test/quantize_model_test_base.py"], # tags = ["no_pip"], +# visibility = [ +# "//learning/brain/mlir/quantization/stablehlo:__subpackages__", +# "//tensorflow/compiler/mlir/quantization:__subpackages__", +# ], # deps = [ # "//third_party/py/mlir:ir", # "//third_party/py/mlir:stablehlo_dialect", @@ -62,6 +67,7 @@ pytype_strict_library( # "//tensorflow/python/ops:nn_ops", # "//tensorflow/python/ops:variables", # "//tensorflow/python/platform:client_testlib", +# "//tensorflow/python/platform:tf_logging", # "//tensorflow/python/saved_model:load", # "//tensorflow/python/saved_model:loader", # "//tensorflow/python/saved_model:save", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index f65c56bc577742..ab0fb1d5662bba 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import os import re from typing import Mapping, Optional, Sequence from absl.testing import parameterized import numpy as np +from google.protobuf import text_format from tensorflow.compiler.mlir.quantization.common.python import testing from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as qc from tensorflow.compiler.mlir.quantization.stablehlo.python import quantization @@ -145,7 +147,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # done in MLIR level. # Tests that the quantized graph outputs similar values. The rtol and atol # values are arbitrary. - self.assertAllClose(new_outputs, expected_outputs, rtol=0.03, atol=0.2) + self.assertAllClose(new_outputs, expected_outputs, rtol=0.3, atol=0.2) # Due to other meta data, the compression is not exactly 1/4. 
self.assertLess( @@ -575,6 +577,114 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 0.65, ) + def test_reuse_calibration_data(self): + _, y_shape, bias_shape, x_signature, y_signature = ( + self._prepare_sample_einsum_datashapes('abc,cde->abde', use_bias=True) + ) + + self._create_einsum_model( + self._input_saved_model_path, + 'abc,cde->abde', + y_shape, + x_signature, + y_signature, + bias_shape, + ) + + # Generate model input data. + rng = np.random.default_rng(seed=42) + input_data = ops.convert_to_tensor( + rng.uniform(low=0.0, high=1.0, size=x_signature).astype('f4') + ) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'x': ops.convert_to_tensor( + np.random.uniform(low=0.0, high=1.0, size=x_signature).astype( + 'f4' + ) + ), + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + calibration_data_dir = self.create_tempdir('calibration_data').full_path + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ) + ] + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + calibration_options=qc.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX, + calibration_data_dir=calibration_data_dir, + ), + ) + + # Run quantization the first time; calibration is expected to run. + with self.assertLogs(level='INFO') as info_logs: + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + self.assertTrue( + self._any_log_contains( + 'Calibration step is executed in graph mode.', + info_logs.records, + ) + ) + module_str = self._extract_first_xla_call_module_op( + self._output_saved_model_path + ) + self.assertTrue( + re.search('stablehlo.dot_general.*xi8>.*xi8>.*xi32>', module_str) + ) + + # Run quantization the second time; calibration is expected to be skipped. + output_saved_model_path_2 = self.create_tempdir('output2').full_path + with self.assertLogs(level='INFO') as info_logs: + quantization.quantize_saved_model( + self._input_saved_model_path, + output_saved_model_path_2, + config, + ) + self.assertFalse( + self._any_log_contains( + 'Calibration step is executed in graph mode.', + info_logs.records, + ) + ) + module_str = self._extract_first_xla_call_module_op( + output_saved_model_path_2 + ) + self.assertTrue( + re.search('stablehlo.dot_general.*xi8>.*xi8>.*xi32>', module_str) + ) + + # Expect both quantized models to produce the same results. + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + new_outputs_1 = root.signatures['serving_default']( + x=ops.convert_to_tensor(input_data) + ) + + root = load.load(output_saved_model_path_2) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + new_outputs_2 = root.signatures['serving_default']( + x=ops.convert_to_tensor(input_data) + ) + + self.assertAllClose(new_outputs_1, new_outputs_2) + @parameterized.named_parameters( ('use_constant_with_int32_input', np.int32, False), ('use_variable_with_int32_input', np.int32, True), @@ -897,7 +1007,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # be exactly the same.
Indirectly proves that the `FunctionNameMatcherSpec` # with regex '.*invalid_function_name.*' did not match the quantizable unit. self.assertAllClose(new_outputs, expected_outputs, rtol=0.04) - self.assertNotAllClose(new_outputs, expected_outputs, rtol=0.00001) + self.assertNotAllClose(new_outputs, expected_outputs, rtol=1e-7) # Due to other meta data, the compression is not exactly 1/4. self.assertLess( @@ -907,6 +1017,72 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 0.4, ) + def test_save_quantization_report_file(self): + """Tests that the quantization report file is created. + + Also tests that it is populated with a textproto of `QuantizationResults`. + """ + input_shape = (1, 16) + filter_shape = (16, 3) + self._create_matmul_model( + input_shape, + filter_shape, + self._input_saved_model_path, + ) + + rng = np.random.default_rng(seed=42) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=input_shape + ).astype(np.float32) + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + report_file_path = self.create_tempfile('report.txtpb').full_path + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ) + ] + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + report_file_path=report_file_path, + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + # Test the contents of the report file, which is a textproto of + # `QuantizationResults`.
+ self.assertTrue(os.path.exists(report_file_path)) + with open(report_file_path, 'r') as f: + quantization_results_textpb = f.read() + + results = qc.QuantizationResults() + text_format.Parse(quantization_results_textpb, results) + + self.assertProtoEquals( + expected_message_maybe_ascii=r""" + results { + quantizable_unit { name: "composite_dot_general_fn_1" } + method { static_range_ptq {} } + } + """, + message=results, + ) + @test_util.run_all_in_graph_and_eager_modes class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): @@ -931,7 +1107,7 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): 'calibration_options': qc.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, calibration_parameters=qc.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + num_bins=10, ), ), }, @@ -939,7 +1115,7 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): 'calibration_options': qc.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, calibration_parameters=qc.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + num_bins=10, ), ), }, @@ -947,7 +1123,7 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): 'calibration_options': qc.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, calibration_parameters=qc.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + num_bins=10, ), ), }, @@ -955,7 +1131,7 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): 'calibration_options': qc.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, calibration_parameters=qc.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + num_bins=10, ), ), }, @@ -1212,9 +1388,8 @@ def test_conv_weight_only_model( self._output_saved_model_path ) - # Tests that the output graph contains subtract and multiply for + # Tests that the output graph contains a multiply op for symmetric # dequantization. - self.assertTrue(re.search('stablehlo.subtract', module_str)) self.assertTrue(re.search('stablehlo.multiply', module_str)) # Tests that the output graph contains float dot_general. self.assertTrue( @@ -1357,6 +1532,58 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # Check add is not quantized. self.assertTrue(re.search(r'stablehlo.add.*f32>', module_str), module_str) + def test_save_quantization_report_file(self): + """Tests that the quantization report file is created. + + Also tests that it is populated with a textproto of `QuantizationResults`. + """ + + input_shape = (1, 3, 4, 3) + filter_shape = (2, 3, 3, 2) + self._create_conv2d_model( + input_shape, + filter_shape, + self._input_saved_model_path, + ) + + report_file_path = self.create_tempfile('report.txtpb').full_path + config = qc.QuantizationConfig( + weight_only_ptq_preset=qc.WeightOnlyPtqPreset(), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + report_file_path=report_file_path, + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + # Test the contents of the report file, which is a textproto of + # `QuantizationResults`.
+ self.assertTrue(os.path.exists(report_file_path)) + with open(report_file_path, 'r') as f: + quantization_results_textpb = f.read() + + results = qc.QuantizationResults() + text_format.Parse(quantization_results_textpb, results) + + self.assertProtoEquals( + expected_message_maybe_ascii=r""" + results { + quantizable_unit { name: "composite_conv_fn_1" } + method { + weight_only_ptq { + input_quantized_types { + key: 1 + value { dimension_specs {} } + } + } + } + } + """, + message=results, + ) + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py index 31c53a4cf20fe9..fef1784fec9370 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py @@ -33,11 +33,13 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import load from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import save as saved_model_save from tensorflow.python.types import core + FUNC_ALIAS = 'some_alias' @@ -164,6 +166,27 @@ def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: ) return model + def _any_log_contains( + self, substring: str, log_record_list: List['logging.LogRecord'] + ) -> bool: + """Returns True if any log record contains the given substring. + + Args: + substring: The substring to search for in the log + messages. + log_record_list: A list of `absl.logging.LogRecord`s. + + Returns: + True if and only if the substring exists in any of the logs in + `log_record_list`. + """ + return any( + map( + lambda log_record: substring in str(log_record.message), + log_record_list, + ) + ) + def _create_matmul_and_same_scale_model( self, input_shape: Sequence[int], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.cc index 3b5ece120bdeb0..517bd117348072 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" namespace stablehlo::quantization::pywrap { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto index 81f2ff3686fbbe..49e8161df3a749 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -77,8 +77,8 @@ message StaticRangePtqPreset { bool enable_full_int_quantization = 3; } -// Applies int8 per-tensor weight-only post-training quantization for all -// dot_general op. 
+// Applies int8 per-channel weight-only post-training quantization for all +// dot_general and convolution ops. message WeightOnlyPtqPreset {} // Metadata specific to the input TensorFlow SavedModel, which may be required @@ -131,10 +131,18 @@ message QuantizationResults { repeated QuantizationResult results = 1; } +// Signals per-channel quantization. When the dimension is not specified, the +// StableHLO quantizer chooses the output feature dimension for convolution and +// the first non-batching, non-contracting dimension for dot_general as the +// quantization dimension. message QuantizedDimension { - int32 dimension = 1; // Should be less than the rank of the quantized tensor. + // Should be less than the rank of the quantized tensor. + optional int32 dimension = 1; } +// Signals quantization type to be per-tensor. +message PerTensor {} + // Corresponds to StableHLO's `QuantizedTensorElementType`. Type parameters such // as `QuantizationParameters` are omitted because they are determined during // quantization. // Currently only supports specifying quantization granularity (e.g. for // per-channel quantization). // TODO: b/331144430 - Support specifying storage types. +// Next ID: 3 message QuantizedType { // Specifies the granularity of quantization parameters for each dimension of // a quantized tensor. If specified, per-channel quantization is applied. If // not specified, per-tensor quantization is applied. // TODO: Make it a `repeated` field to be able to express multi-channel / // sub-channel quantization. - QuantizedDimension dimension_specs = 1; + oneof type { + QuantizedDimension dimension_specs = 1; + PerTensor per_tensor = 2; + } } // A quantization method representing "do not quantize". Mostly used for @@ -266,7 +278,7 @@ message DebuggerConfig { } // Defines various calibration options. -// Next ID: 4 +// Next ID: 6 message CalibrationOptions { // Configurations for calibration methods. // Next ID: 7 @@ -296,10 +308,8 @@ message CalibrationOptions { // Parameters required for calibration. // Next ID: 4 message CalibrationParameters { - // The number of bins when histogram is initialized. It can be increased - // because histogram is dynamically expanded by sample inputs. - // initial_num_bins is 256 by default. - int32 initial_num_bins = 1; + // The number of histogram bins. Defaults to 512. + int32 num_bins = 1; // min_percentile is only used in HISTOGRAM_PERCENTILE. // min_percentile is 0.001 by default. float min_percentile = 2; @@ -321,11 +331,19 @@ message CalibrationOptions { // Configures representative dataset. Each item corresponds to a // representative dataset used to calibrate a function. repeated RepresentativeDatasetConfig representative_datasets = 3; + + // The path to save calibration statistics data. If not set, a temporary + // directory is used. + string calibration_data_dir = 4; + + // Whether to regenerate the calibration data even when it already exists in + // `calibration_data_dir`. Defaults to False, i.e. existing calibration data + // is reused. + bool force_regenerate_calibration_data = 5; } // Quantization configuration for StableHLO Quantizer. This is the primary // message containing all configurable options. -// Next ID: 8 +// Next ID: 9 message QuantizationConfig { // Config presets provide predefined popular or common quantization specs. // Lightweight users may choose one of the presets for quick experiments. Each @@ -354,4 +372,9 @@ // activation of static range quantization (SRQ).
Quantization calibration // method is set to MIN_MAX by default. CalibrationOptions calibration_options = 6; + + // Path to file to save the quantization report, which is essentially a + // textproto rendering of `QuantizationResults`. If not set, the report will + // only be emitted to stdout. + optional string report_file_path = 8; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir index 2f149281fbd0be..61f4b27e66af90 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir @@ -8,9 +8,9 @@ // int ops. func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> - %1:4 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %1:4 = "tf.CustomAggregator"(%arg0) <{id = "1", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {min = 7.547870e-07 : f32, max = 0.999992311 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) %2 = "tf.XlaCallModule"(%1#0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> - %3:4 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 5.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -3.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + %3:4 = "tf.CustomAggregator"(%2) <{id = "2", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {min = -17.5216827 : f32, max = 18.3033524 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) return %3#0 : tensor<1x3xf32> } func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { @@ -36,9 +36,9 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: func.func @main_no_unpack(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> - %1:4 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %1:4 = "tf.CustomAggregator"(%arg0) <{id = "1", calibration_method = 1 : 
i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {device = "", max = 0.999992311 : f32, min = 7.547870e-07 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) %2 = "tf.XlaCallModule"(%1#0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> - %3:4 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 5.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -3.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + %3:4 = "tf.CustomAggregator"(%2) <{id = "2", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {device = "", max = 18.3033524 : f32, min = -17.5216827 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) return %3#0 : tensor<1x3xf32> } func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { @@ -47,7 +47,7 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: } // CHECK-NO-UNPACK-LABEL: func.func @main_no_unpack // CHECK-NO-UNPACK-SAME: (%[[ARG_0:.+]]: tensor<1x1024xf32>) -> tensor<1x3xf32> -// CHECK-NO-UNPACK-DAG: %[[CONST:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32:1, {{.*}}>> +// CHECK-NO-UNPACK-DAG: %[[CONST:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1024x3xi8>}> : () -> tensor<1024x3x!quant.uniform:f32:1, {{.*}}>> // CHECK-NO-UNPACK: %[[QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x1024xf32>) -> tensor<1x1024x!quant.uniform> // CHECK-NO-UNPACK: %[[DOT:.+]] = stablehlo.dot_general %[[QUANTIZE_0]], %[[CONST]] // CHECK-NO-UNPACK: %[[QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir index 954323af9ef7ad..0c5e7a7cab09f2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir @@ -8,10 +8,10 @@ func.func @main(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { } // CHECK: @main(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32> // CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{calibration_method = 1 : i32, id = "composite_dot_general_fn_1_arg_0_calibration_method_1", max_percentile = 0.000000e+00 : f32, 
min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) // CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" -// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{calibration_method = 1 : i32, id = "composite_dot_general_fn_1_calibration_method_1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) // CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32> // CHECK: } // CHECK: } @@ -28,10 +28,10 @@ func.func @serving_default(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { } // CHECK: @serving_default(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32> // CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{calibration_method = 1 : i32, id = "composite_dot_general_fn_1_arg_0_calibration_method_1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) // CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" -// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{calibration_method = 1 : i32, id = "composite_dot_general_fn_1_calibration_method_1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) // CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32> // CHECK: } // CHECK: } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/defer_activation_transpose.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/defer_activation_transpose.mlir index 96b270f8b888f9..d9db49de957ac9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/defer_activation_transpose.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/defer_activation_transpose.mlir @@ -118,14 +118,14 @@ func.func @reduce_window_max_activation_transpose(%arg0: tensor<1x16x16x4xf32>) // Check that the body is not modified. 
// CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]]) +// CHECK: <{window_dimensions = array, window_strides = array}> // CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor, %[[REDUCE_ARG_1:.+]]: tensor): // CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]] // CHECK: stablehlo.return %[[MAX]] // Check that the attributes window_dimensions & window_strides are also // permutated to match the new input shape. -// CHECK: {window_dimensions = array, window_strides = array} -// CHECK-SAME: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x8x8x4xf32> +// CHECK: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x8x8x4xf32> // Check that a `stablehlo.transpose` is added to the result to match the shape // of the users. @@ -162,6 +162,7 @@ func.func @reduce_window_max_activation_transpose_explicit_optional_attrs( // Check that the body is not modified. // CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]]) +// CHECK: <{base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array}> // CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor, %[[REDUCE_ARG_1:.+]]: tensor): // CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]] // CHECK: stablehlo.return %[[MAX]] @@ -169,8 +170,7 @@ func.func @reduce_window_max_activation_transpose_explicit_optional_attrs( // Check that the attributes window_dimensions & window_strides along with // optional attributes base_dilations and window_dilations are also permutated // to match the new input shape. -// CHECK: {base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array} -// CHECK-SAME: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x15x15x4xf32> +// CHECK: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x15x15x4xf32> // Check that a `stablehlo.transpose` is added to the result to match the shape // of the users. 
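The new test file below covers the `-stablehlo-insert-calibration-statistics-saver` pass: within each function body holding `tf.CustomAggregator` ops, the min/max/histogram results are routed into a single `tf.CalibrationStatisticsSaver` whose `output_file_path` is derived from the enclosing function name (`serving_default_0.pb`, `main_0.pb`, ...), and `tf.If`/`tf.IfRegion` ops are flipped to stateful so the savers inside their branches are not pruned. What follows is a rough conceptual sketch of the grouping step only, assuming one saver per function body; the result indices and builder details are assumptions for illustration, not the pass implementation (which also handles nested regions):

#include <string>

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"

// Conceptual sketch: collect every aggregator's statistics results and feed
// them to one saver op inserted before the function terminator.
static void InsertSaverSketch(mlir::func::FuncOp func, mlir::OpBuilder& builder,
                              int file_index) {
  llvm::SmallVector<mlir::Value> stats;
  llvm::SmallVector<mlir::Attribute> ids, methods;
  func.walk([&](mlir::Operation* op) {
    if (op->getName().getStringRef() != "tf.CustomAggregator") return;
    // Results 1, 2 and 3 carry the collected min, max and histogram.
    stats.append({op->getResult(1), op->getResult(2), op->getResult(3)});
    ids.push_back(op->getAttr("id"));
    methods.push_back(op->getAttr("calibration_method"));
  });
  if (stats.empty()) return;  // e.g. composite function bodies get no saver.

  builder.setInsertionPoint(func.getBody().front().getTerminator());
  mlir::OperationState state(func.getLoc(), "tf.CalibrationStatisticsSaver");
  state.addOperands(stats);
  state.addAttribute("ids", builder.getArrayAttr(ids));
  state.addAttribute("calibration_methods", builder.getArrayAttr(methods));
  state.addAttribute("output_file_path",
                     builder.getStringAttr(func.getSymName().str() + "_" +
                                           std::to_string(file_index) + ".pb"));
  builder.create(state);
}

These per-function statistics files are presumably what allows the `test_reuse_calibration_data` test above to skip the calibration step on a second run once `calibration_data_dir` is set.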
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_calibration_statistics_saver.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_calibration_statistics_saver.mlir new file mode 100644 index 00000000000000..f80ce73ff88bf1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_calibration_statistics_saver.mlir @@ -0,0 +1,219 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -mlir-disable-threading -stablehlo-insert-calibration-statistics-saver | FileCheck %s + +func.func @serving_default(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x2x2x2xf32>) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}} { + %cst = "tf.Const"() <{value = dense<[[[[-0.891899645, 0.392044574], [0.77720493, 1.31188095], [0.255048186, 2.700150e+00]], [[-1.08111858, -0.406604826], [-0.298575521, -2.25356531], [-1.00201964, 2.54532099]], [[-1.34911358, 0.279911458], [-0.868258893, -1.36708188], [0.866317451, -2.05804896]]], [[[-0.591397941, 0.331505477], [0.715151429, 2.64073896], [1.27163255, 0.206143498]], [[0.474211812, 1.45044816], [0.119936548, 2.54149938], [-0.939900994, 0.438387245]], [[-1.12486279, -1.09022558], [0.82202208, 1.04652023], [1.30316162, 2.62054276]]]]> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 5 : i32, id = "0", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) + %0 = "tf.Conv2D"(%output, %cst) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + %output_1, %min_2, %max_3, %histogram_4 = "tf.CustomAggregator"(%0) <{calibration_method = 5 : i32, id = "1", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x2x2x2xf32>) -> (tensor<1x2x2x2xf32>, tensor, tensor, tensor<512xi64>) + %1 = "tf.Identity"(%output_1) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %1 : tensor<1x2x2x2xf32> +} +// CHECK-LABEL: @serving_default +// CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" +// CHECK-SAME: <{calibration_method = 5 : i32, id = "0", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" +// CHECK-SAME: <{calibration_method = 5 : i32, id = "1", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x2x2x2xf32>) -> (tensor<1x2x2x2xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_O]], %[[MAX_O]], %[[HISTOGRAM_0]], %[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) +// CHECK-SAME: <{calibration_methods = [5 : i32, 5 : i32], ids = ["0", "1"], output_file_path = "serving_default_0.pb"}> : (tensor, tensor, tensor<512xi64>, tensor, tensor, tensor<512xi64>) -> () +// CHECK: return + +// ----- + +// No CustomAggregator ops exist.
+func.func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x2x2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> : (tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %2 : tensor<1x2x2x2xf32> +} +// CHECK-LABEL: @composite_conv2d_with_bias_and_relu6_fn_1 +// CHECK-NOT: "tf.CalibrationStatisticsSaver" + +// ----- + +// Check the IfOp is set to stateful. +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: func.func @serving_default + // CHECK: "tf.If" + // CHECK-SAME: is_stateless = false + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_0 = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %0 = "tf.Sum"(%arg0, %cst) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.If"(%1, %arg0) <{else_branch = @cond_false_80, is_stateless = true, then_branch = @cond_true_70}> {Tcond = i1, Tin = [f32], Tout = [i1, f32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor, tensor<1x4xf32>) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @cond_false_80 + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "cond_false_80_0.pb" + func.func private @cond_false_80(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_false_8"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = 
"tf.PartitionedCall"(%output, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_2, %min_3, %max_4, %histogram_5 = "tf.CustomAggregator"(%1) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %2 = "tf.Identity"(%output_2) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @cond_true_70 + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "cond_true_70_0.pb" + func.func private @cond_true_70(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_true_7"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%output, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_2, %min_3, %max_4, %histogram_5 = "tf.CustomAggregator"(%1) <{calibration_method = 1 : i32, id = "3", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %2 = "tf.Identity"(%output_2) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = 
"tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +// Check the IfRegion is set to stateful. +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: func.func @serving_default + // CHECK: "tf.IfRegion" + // CHECK-SAME: is_stateless = false + + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "serving_default_0.pb" + + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "serving_default_1.pb" + + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "serving_default_2.pb" + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_2 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_3 = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_4 = "tf.Const"() <{value = dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_5 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.Sum"(%output, %cst_0) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.IfRegion"(%1) <{_else_func_name = "cond_false_80", _then_func_name = "cond_true_70", is_stateless = true}> ({ + %4 = "tf.Identity"(%cst_3) {device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%output, %cst_1, %cst_2) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_6, %min_7, %max_8, %histogram_9 = "tf.CustomAggregator"(%5) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %6 = "tf.Identity"(%output_6) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }, { + %4 = "tf.Identity"(%cst_3) 
{device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%output, %cst_4, %cst_5) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_6, %min_7, %max_8, %histogram_9 = "tf.CustomAggregator"(%5) <{calibration_method = 1 : i32, id = "2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %6 = "tf.Identity"(%output_6) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + func.func @main(%arg0: tensor<10x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<10x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<0.000000e+00>: tensor<10x1024x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x1024xf32>) -> (tensor<10x1x1024xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.XlaCallModule"(%output, %cst) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %output_0, %min_1, %max_2, 
%histogram_3 = "tf.CustomAggregator"(%0) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x3xf32>) -> (tensor<10x1x3xf32>, tensor, tensor, tensor<0xi64>) + return %output_0 : tensor<10x1x3xf32> + } + // CHECK-LABEL: @main + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" + // CKECK-SAME: <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" + // CKECK-SAME: <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_O]], %[[MAX_O]], %[[HISTOGRAM_0]], %[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) + // CHECK-SAME: <{calibration_methods = [1 : i32, 1 : i32], ids = ["0", "1"], output_file_path = "main_0.pb"}> : (tensor, tensor, tensor<0xi64>, tensor, tensor, tensor<0xi64>) -> () + // CHECK: return + + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %1 = stablehlo.maximum %0, %cst : tensor<10x1x3xf32> + return %1 : tensor<10x1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_with_relu_fn_1 + // CHECK-NOT: "tf.CalibrationStatisticsSaver" +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: func.func @main + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "main_0.pb" + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "main_1.pb" + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "main_2.pb" + func.func @main(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<1.000000e+01> : tensor + %cst_0 = stablehlo.constant dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32> + %c = stablehlo.constant dense : tensor + %cst_1 = stablehlo.constant dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32> + %cst_2 = stablehlo.constant dense<-0.000000e+00> : tensor + %cst_3 = stablehlo.constant dense<[[0.335351914, 0.084816426, -0.664676845]]> : tensor<1x3xf32> + %cst_4 = stablehlo.constant dense<[[0.117216609, 0.933735609, 0.0728900209]]> : tensor<1x3xf32> + %output, %min, %max, 
%histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = stablehlo.reduce(%output init: %cst_2) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x4xf32>, tensor) -> tensor + %1 = stablehlo.compare GT, %0, %cst : (tensor, tensor) -> tensor + %2:2 = "stablehlo.if"(%1) ({ + %3 = "tf.XlaCallModule"(%output, %cst_0, %cst_3) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_2, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_2", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %output_5, %min_6, %max_7, %histogram_8 = "tf.CustomAggregator"(%3) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + stablehlo.return %c, %output_5 : tensor, tensor<1x3xf32> + }, { + %3 = "tf.XlaCallModule"(%output, %cst_1, %cst_4) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_1, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %output_5, %min_6, %max_7, %histogram_8 = "tf.CustomAggregator"(%3) <{calibration_method = 1 : i32, id = "2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + stablehlo.return %c, %output_5 : tensor, tensor<1x3xf32> + }) : (tensor) -> (tensor, tensor<1x3xf32>) + return %2#1 : tensor<1x3xf32> + } + func.func private @composite_dot_general_with_bias_same_shape_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_with_bias_same_shape_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_calibration_statistics_saver_with_skipping.mlir 
b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_calibration_statistics_saver_with_skipping.mlir new file mode 100644 index 00000000000000..97d546afe2b723 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_calibration_statistics_saver_with_skipping.mlir @@ -0,0 +1,47 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-insert-calibration-statistics-saver='aggregator-ops-to-ignore=skipping_id' | FileCheck %s + +func.func @serving_default(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x2x2x2xf32>) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}} { + %cst = "tf.Const"() <{value = dense<[[[[-0.891899645, 0.392044574], [0.77720493, 1.31188095], [0.255048186, 2.700150e+00]], [[-1.08111858, -0.406604826], [-0.298575521, -2.25356531], [-1.00201964, 2.54532099]], [[-1.34911358, 0.279911458], [-0.868258893, -1.36708188], [0.866317451, -2.05804896]]], [[[-0.591397941, 0.331505477], [0.715151429, 2.64073896], [1.27163255, 0.206143498]], [[0.474211812, 1.45044816], [0.119936548, 2.54149938], [-0.939900994, 0.438387245]], [[-1.12486279, -1.09022558], [0.82202208, 1.04652023], [1.30316162, 2.62054276]]]]> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 5 : i32, id = "skipping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) + %0 = "tf.Conv2D"(%output, %cst) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + %output_1, %min_2, %max_3, %histogram_4 = "tf.CustomAggregator"(%0) <{calibration_method = 5 : i32, id = "keeping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x2x2x2xf32>) -> (tensor<1x2x2x2xf32>, tensor, tensor, tensor<512xi64>) + %1 = "tf.Identity"(%output_1) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %1 : tensor<1x2x2x2xf32> +} +// CHECK-LABEL: @serving_default +// CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" +// CHECK-SAME: <{calibration_method = 5 : i32, id = "skipping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" +// CHECK-SAME: <{calibration_method = 5 : i32, id = "keeping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) +// CHECK-SAME: <{calibration_methods = [5 : i32], ids = ["keeping_id"], output_file_path = "serving_default_0.pb"}> : (tensor, tensor, tensor<512xi64>) -> () +// CHECK: return + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + func.func
@main(%arg0: tensor<10x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<10x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<0.000000e+00>: tensor<10x1024x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "skipping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x1024xf32>) -> (tensor<10x1x1024xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.XlaCallModule"(%output, %cst) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %output_0, %min_1, %max_2, %histogram_3 = "tf.CustomAggregator"(%0) <{calibration_method = 1 : i32, id = "keeping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x3xf32>) -> (tensor<10x1x3xf32>, tensor, tensor, tensor<0xi64>) + return %output_0 : tensor<10x1x3xf32> + } + // CHECK-LABEL: @main + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" + // CHECK-SAME: <{calibration_method = 1 : i32, id = "skipping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" + // CHECK-SAME: <{calibration_method = 1 : i32, id = "keeping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) + // CHECK-SAME: <{calibration_methods = [1 : i32], ids = ["keeping_id"], output_file_path = "main_0.pb"}> : (tensor, tensor, tensor<0xi64>) -> () + // CHECK: return + + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %1 = stablehlo.maximum %0, %cst : tensor<10x1x3xf32> + return %1 : tensor<10x1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_with_relu_fn_1 + // CHECK-NOT: "tf.CalibrationStatisticsSaver" +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir index 89ff96efecf471..6a194023dbbfc1 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir @@ -1,14 +1,15 @@ // RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-insert-weight-param | FileCheck %s -// Test that q/dq pair is inserted between constant and XlaCallModule op -// with quantizable trait and function name containing conv. +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with empty `weight_only_ptq` method +// and function name containing conv. -func.func @qdq_for_conv_weight(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> attributes {tf._original_func_name = "main_0"} { +func.func @qdq_for_conv_weight_empty(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %0 = "tf.XlaCallModule"(%arg0, %cst) { Sout = [#tf_type.shape<1x2x2x2>], _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -16,25 +17,28 @@ func.func @qdq_for_conv_weight(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32 return %0 : tensor<1x2x2x2xf32> } -// CHECK-LABEL: func.func @qdq_for_conv_weight +// CHECK-LABEL: func.func @qdq_for_conv_weight_empty // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> -// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> -// CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> -// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x2x2x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> +// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }" +// CHECK-SAME: (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> // CHECK: return %[[CALL]] : tensor<1x2x2x2xf32> // ----- -// Test that q/dq pair is inserted between constant and XlaCallModule op -// with quantizable trait and function name containing dot_general. +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with empty `weight_only_ptq` method and +// function name containing dot_general. 
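An aside on the expected numbers: the per-tensor scale 0.0023622048182750312 in the CHECK lines above (and in the per-tensor tests that follow) is consistent with symmetric int8 weight quantization of a constant whose elements are all 3.000000e-01, i.e. scale = max(|w|) / 127, with the literal first rounded to float32. A minimal sketch of that arithmetic (an assumption about how the scale arises, not code from the pass):

import numpy as np

# Every weight element is 3.000000e-01, so max(|w|) is the float32 value of 0.3.
max_abs = np.float64(np.float32(0.3))
qmax = 127  # symmetric int8 quantization uses the range [-127, 127]
scale = max_abs / qmax
print(scale)  # 0.0023622048182750312, the scale in the quant.uniform types above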
-func.func @qdq_for_dot_general_weight(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { +func.func @qdq_for_dot_general_weight_empty(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> %0 = "tf.XlaCallModule"(%arg0, %cst) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -42,16 +46,228 @@ func.func @qdq_for_dot_general_weight(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } -// CHECK-LABEL: func.func @qdq_for_dot_general_weight +// CHECK-LABEL: func.func @qdq_for_dot_general_weight_empty // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> -// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> -// CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32> -// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }" +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> // CHECK: return %[[CALL]] : tensor<1x3xf32> // ----- +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `per_tensor` and function name containing conv. 
+ +func.func @qdq_for_conv_weight_per_tensor(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], _entry_function = @composite_conv_fn, + _original_entry_function = "composite_conv_fn", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + return %0 : tensor<1x2x2x2xf32> +} + +// CHECK-LABEL: func.func @qdq_for_conv_weight_per_tensor +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}" +// CHECK-SAME: (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> +// CHECK: return %[[CALL]] : tensor<1x2x2x2xf32> + +// ----- + +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `per_tensor` and function name containing dot_general. 
+ +func.func @qdq_for_dot_general_weight_per_tensor(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}", _stablehlo_module_attrs = {}, + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @qdq_for_dot_general_weight_per_tensor +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> +// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}" +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] : tensor<1x3xf32> + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quantized_type` without specified quantization dimension and function name +// containing conv.
+ +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_conv_weight_per_channel_default(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], version = 5 : i64, + _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + + // CHECK: func.func private @qdq_for_conv_weight_per_channel_default(%[[ARG0:.+]]: tensor<1x3x4x3xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<2x3x3x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + // CHECK: func private @composite_conv_fn + // CHECK: %[[CONV:.+]] = stablehlo.convolution + // CHECK: return %[[CONV]] +} + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quantized_type` without specified quantization dimension and function name +// containing dot_general.
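In the per-channel variants around this point, the expected type carries one scale per channel along quantization dimension 3, and because the weight constant is uniform, every channel ends up with the same scale. A rough numpy model of per-channel scale computation under the same symmetric int8 assumption (a hypothetical helper, not the pass's implementation):

import numpy as np

def per_channel_scales(w, axis, qmax=127):
    # Reduce |w| over every dimension except the quantization axis.
    reduce_axes = tuple(d for d in range(w.ndim) if d != axis)
    return np.abs(w).max(axis=reduce_axes) / qmax

w = np.full((2, 3, 3, 2), np.float32(0.3), dtype=np.float64)
print(per_channel_scales(w, axis=3))  # two equal scales of ~0.0023622048182750312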
+ +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_dot_general_weight_per_channel_default(%arg0: tensor<4x3x6x5xf32>) -> tensor<4x3x6x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<4x3x5x2xf32>} : () -> tensor<4x3x5x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<4x3x6x2>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func.func private @qdq_for_dot_general_weight_per_channel_default(%[[ARG0:.+]]: tensor<4x3x6x5xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<4x3x5x2xf32>}> : () -> tensor<4x3x5x2xf32> + // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<4x3x5x2xf32>) -> tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<4x3x5x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_dot_general_fn(%arg0: tensor<4x3x6x5xf32>, %arg1: tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func private @composite_dot_general_fn + // CHECK: %[[DOT:.+]] = stablehlo.dot_general + // CHECK: return %[[DOT]] +} + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quantized_type` with specified quantization dimension and function name +// containing conv.
+ +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_conv_weight_per_channel(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], version = 5 : i64, + _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + + // CHECK: func.func private @qdq_for_conv_weight_per_channel(%[[ARG0:.+]]: tensor<1x3x4x3xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<2x3x3x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + // CHECK: func private @composite_conv_fn + // CHECK: %[[CONV:.+]] = stablehlo.convolution + // CHECK: return %[[CONV]] +} + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quantized_type` with specified quantization dimension and function name +// containing dot_general.
+ +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_dot_general_weight_per_channel(%arg0: tensor<4x3x6x5xf32>) -> tensor<4x3x6x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<4x3x5x2xf32>} : () -> tensor<4x3x5x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<4x3x6x2>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func.func private @qdq_for_dot_general_weight_per_channel(%[[ARG0:.+]]: tensor<4x3x6x5xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<4x3x5x2xf32>}> : () -> tensor<4x3x5x2xf32> + // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<4x3x5x2xf32>) -> tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<4x3x5x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_dot_general_fn(%arg0: tensor<4x3x6x5xf32>, %arg1: tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func private @composite_dot_general_fn + // CHECK: %[[DOT:.+]] = stablehlo.dot_general + // CHECK: return %[[DOT]] +} + +// ----- + // Test that q/dq pair is not inserted between constant and XlaCallModule op // whose entry function name does not include conv nor dot_general. 
@@ -59,7 +275,7 @@ func.func @no_qdq_except_conv_and_dot_general(%arg0: tensor<2x3x2xi64>) -> tenso %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<3x4x2xf32>} : () -> tensor<3x4x2xf32> %0 = "tf.XlaCallModule"(%cst, %arg0) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_gather_fn, - _original_entry_function = "composite_gather_fn", + _original_entry_function = "composite_gather_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -81,7 +297,7 @@ func.func @no_qdq_for_non_weight_constant(%arg0: tensor<1x2xf32>, %arg1: tensor< %0 = "tf.XlaCallModule"(%arg0, %arg1, %cst) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_fn, _original_entry_function = "composite_dot_general_with_bias_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -96,7 +312,7 @@ func.func @no_qdq_for_non_weight_constant(%arg0: tensor<1x2xf32>, %arg1: tensor< // ----- // Test that q/dq pair is not inserted between constant and XlaCallModule op -// without quantizable trait. +// without `weight_only_ptq` method. func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> @@ -116,6 +332,27 @@ func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3 // ----- +// Test that q/dq pair is not inserted between constant and XlaCallModule op +// with different method. + +func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], + disabled_checks = [], has_token_input_output = false, module = "", + platforms = [], _quantization_method = "static_range_ptq { }", version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_not_quantizable_call +// CHECK-NOT: quantfork.qcast +// CHECK-NOT: quantfork.dcast + +// ----- + // Test that q/dq pair is not inserted when constant has multiple users. 
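Taken together, the negative tests above and the multiple-users test below outline when the pass declines to insert a q/dq pair. A condensed model of that gating, distilled from what the tests exercise rather than from the pass source (the function and parameter names are hypothetical):

def should_insert_weight_qdq(entry_fn_name, quantization_method,
                             weight_is_constant, weight_num_users):
    # Only conv/dot_general composites with a weight_only_ptq method and a
    # single-use constant weight operand receive a q/dq pair.
    supported_fn = "conv" in entry_fn_name or "dot_general" in entry_fn_name
    weight_only = quantization_method.startswith("weight_only_ptq")
    return supported_fn and weight_only and weight_is_constant and weight_num_users == 1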
func.func @no_qdq_for_multiple_users(%arg0: tensor<2x2xf32>) -> tensor<2x3xf32> attributes {tf._original_func_name = "main_0"} { @@ -123,7 +360,7 @@ func.func @no_qdq_for_multiple_users(%arg0: tensor<2x2xf32>) -> tensor<2x3xf32> %0 = "tf.XlaCallModule"(%arg0, %cst) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir index fa722c2fc71c88..eb4c2416024512 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir @@ -247,9 +247,9 @@ func.func @conv_with_relu_dynamic_fn(%arg0: tensor) -> tensor // CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) // CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[CONV]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF]] // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM]] // CHECK: return %[[MAX]] : tensor @@ -293,9 +293,9 @@ func.func @dot_general_with_relu_dynamic_fn(%arg0: tensor) -> tenso // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_relu_dynamic_fn_1 -// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 // CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[DOT_GENERAL]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF]] // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM]] // CHECK: return %[[MAX]] : tensor @@ -342,9 +342,9 @@ func.func @conv_with_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> // CHECK: } // CHECK-LABEL: private @composite_conv_with_relu6_fn_1 -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> -// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[CONV]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor<1x3x3x4xf32> // CHECK: } @@ -367,9 +367,9 @@ func.func @dot_general_with_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x6 // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_relu6_fn_1 -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> -// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], 
%[[DOT_GENERAL]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> // CHECK: } @@ -392,9 +392,9 @@ func.func @conv_with_relu6_dynamic_fn(%arg0: tensor) -> tensor -// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[CONV]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor // CHECK: } @@ -417,9 +417,9 @@ func.func @dot_general_with_relu6_dynamic_fn(%arg0: tensor) -> tens // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_relu6_fn_1 -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> -// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[DOT_GENERAL]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor // CHECK: } @@ -444,8 +444,8 @@ func.func @dot_general_with_bias_same_shape_and_relu_fn(%arg0: tensor<1x1x167xf3 // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_and_relu_fn_1 -// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] // CHECK: return %[[MAX]] : tensor<1x1x64xf32> @@ -472,9 +472,9 @@ func.func @conv_with_bias_and_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x // CHECK: } // CHECK-LABEL: private @composite_conv_with_bias_and_relu_fn_1 -// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] // CHECK: return %[[MAX]] : tensor<1x3x3x4xf32> @@ -501,9 +501,9 @@ func.func @dot_general_with_bias_and_relu_fn(%arg0: tensor<1x1x167xf32>) -> tens // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_bias_and_relu_fn_1 -// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] // CHECK: return %[[MAX]] : tensor<1x1x64xf32> @@ -533,12 +533,12 @@ func.func @conv_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> t // CHECK: } // CHECK-LABEL: private @composite_conv_with_bias_and_relu_dynamic_fn_1 -// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) // CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[CONV]] // CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] // CHECK: %[[ADD:.*]] = 
stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] // CHECK: %[[SHAPE_OF_1:.*]] = shape.shape_of %[[ADD]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF_1]] // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[DYNAMIC_BROADCAST_IN_DIM_1]] // CHECK: return %[[MAX]] : tensor @@ -591,12 +591,12 @@ func.func @dot_general_with_bias_and_relu_dynamic_fn(%arg0: tensor) // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_bias_and_relu_dynamic_fn_1 -// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 // CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]] // CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] // CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] // CHECK: %[[SHAPE_OF_1:.*]] = shape.shape_of %[[ADD]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF_1]] // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[DYNAMIC_BROADCAST_IN_DIM_1]] // CHECK: return %[[MAX]] : tensor @@ -623,10 +623,10 @@ func.func @dot_general_with_bias_same_shape_and_relu6_fn(%arg0: tensor<1x1x167xf // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_and_relu6_fn_1 -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> -// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> // CHECK: } @@ -653,11 +653,11 @@ func.func @conv_with_bias_and_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3 // CHECK: } // CHECK-LABEL: private @composite_conv_with_bias_and_relu6_fn_1 -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> -// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor<1x3x3x4xf32> // CHECK: } @@ -684,11 +684,11 @@ func.func @dot_general_with_bias_and_relu6_fn(%arg0: tensor<1x1x167xf32>) -> ten // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_bias_and_relu6_fn_1 -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> -// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant 
dense<0.000000e+00> // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> // CHECK: } @@ -716,12 +716,12 @@ func.func @conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> // CHECK: } // CHECK-LABEL: private @composite_conv_with_bias_and_relu6_dynamic_fn_1 -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> -// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) // CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[CONV]] // CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor // CHECK: } @@ -771,12 +771,12 @@ func.func @dot_general_with_bias_and_relu6_dynamic_fn(%arg0: tensor // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1 -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> -// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 // CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]] // CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> // CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor // CHECK: } @@ -808,3 +808,48 @@ func.func @gather_fn() -> tensor<2x3x2x2xi32> { // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%arg0, %arg1) // CHECK: return %[[GATHER]] : tensor<2x3x2x2xi32> // CHECK: } + +// ----- + +// Test that the names of composite functions are deterministic. There are 3 +// unsorted functions in this module and each function has 2 quantizable ops.
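The CHECK expectations below group the indices by function in lexicographic order: @conv_1_fn owns composite_conv_fn_1 and _2, @conv_2_fn owns _3 and _4, @conv_3_fn owns _5 and _6, with the two lifted calls inside each function numbered in descending source order. One way to reproduce that numbering, as an illustration of the determinism being tested rather than the lifting pass itself:

counter = 0
assignment = {}
for fn in sorted(["conv_3_fn", "conv_1_fn", "conv_2_fn"]):
    ops = ["first_conv", "second_conv"]   # two quantizable ops per function
    for op in reversed(ops):              # the later op receives the smaller index
        counter += 1
        assignment[(fn, op)] = f"composite_conv_fn_{counter}"
# assignment[("conv_1_fn", "first_conv")] == "composite_conv_fn_2", matching the CHECKs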
+module { + func.func @conv_3_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } + + func.func @conv_1_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } + + func.func @conv_2_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } +} + +// CHECK-LABEL: @conv_3_fn +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_6, _original_entry_function = "composite_conv_fn_6" +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_5, _original_entry_function = "composite_conv_fn_5" + +// CHECK-LABEL: @conv_1_fn +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_2, _original_entry_function = "composite_conv_fn_2" +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_1, _original_entry_function = "composite_conv_fn_1" + +// CHECK-LABEL: @conv_2_fn +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_4, _original_entry_function = "composite_conv_fn_4" +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_3, _original_entry_function = "composite_conv_fn_3" \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/post_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/post_quantize.mlir index ae2f57081e40f7..301a0661633425 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/post_quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/post_quantize.mlir @@ -37,7 +37,7 @@ func.func @remove_volatile_qdq_with_requantization(%arg0: tensor<3x2xf32>) -> te // CHECK-LABEL: @quantize_constant 
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3xf32> func.func @quantize_constant(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> { - // CHECK-DAG: %[[QCST:.*]] = stablehlo.constant() {value = dense<-78> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + // CHECK-DAG: %[[QCST:.*]] = stablehlo.constant() <{value = dense<-78> : tensor<3x2xi8>}> : () -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> // CHECK-DAG: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]] // CHECK-NOT: "quantfork.qcast" // CHECK: %[[DOT:.*]] = stablehlo.dot %[[Q1]], %[[QCST]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir index d94e1ca3787a3c..f5626e8b1506be 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir @@ -14,12 +14,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q2]], %[[Q1]]) // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[CALL]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + // CHECK-SAME: window_dimensions = array // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> // CHECK: stablehlo.return %[[MAX]] : tensor> - // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> - // CHECK-SAME: window_dimensions = array - // CHECK-SAME: (tensor<2x3x1x3x!quant.uniform>, tensor>) -> tensor<2x3x1x3x!quant.uniform> + // CHECK: (tensor<2x3x1x3x!quant.uniform>, tensor>) -> tensor<2x3x1x3x!quant.uniform> // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[REDUCE]]) // CHECK: return %[[DQ]] @@ -70,12 +70,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]]) // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[Q1]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + // CHECK-SAME: window_dimensions = array // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> // CHECK: stablehlo.return %[[MAX]] : tensor> - // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> - // CHECK-SAME: window_dimensions = array - // CHECK-SAME: (tensor<2x3x1x1024x!quant.uniform>, tensor>) -> tensor<2x3x1x1024x!quant.uniform> + // CHECK: (tensor<2x3x1x1024x!quant.uniform>, tensor>) -> tensor<2x3x1x1024x!quant.uniform> // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[CST1]]) // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[REDUCE]], %[[Q2]]) @@ -132,12 +132,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL]] // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[RESHAPE]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64> + // CHECK-SAME: window_dimensions = array // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> // CHECK: stablehlo.return %[[MAX]] : tensor> - // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 
0]]> : tensor<3x2xi64>
- // CHECK-SAME: window_dimensions = array
- // CHECK-SAME: (tensor<2x3x3x!quant.uniform>, tensor>) -> tensor<2x3x3x!quant.uniform>
+ // CHECK: (tensor<2x3x3x!quant.uniform>, tensor>) -> tensor<2x3x3x!quant.uniform>
// CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[REDUCE]])
// CHECK: return %[[DQ]]
@@ -191,12 +191,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
// CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]])
// CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[Q1]], %[[Q0]])
+ // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>
+ // CHECK-SAME: window_dimensions = array
// CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor>
// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor>
// CHECK: stablehlo.return %[[MAX]] : tensor>
- // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>
- // CHECK-SAME: window_dimensions = array
- // CHECK-SAME: (tensor<2x3x1024x!quant.uniform>, tensor>) -> tensor<2x3x1024x!quant.uniform>
+ // CHECK: (tensor<2x3x1024x!quant.uniform>, tensor>) -> tensor<2x3x1024x!quant.uniform>
// CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[REDUCE]]
// CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[CST1]])
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir
index 25aab3044a3496..7a905dfbe58a9e 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir
@@ -291,17 +291,17 @@ module attributes {tf_saved_model.semantics} {
// CHECK-SAME: %[[ARG2:.*]]: tensor<2x3x2xi64>
func.func private @composite_and_gather(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x2xf32>, %arg2: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> {
// CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]]) {volatile} : (tensor<3x4x5xf32>) -> tensor<3x4x5x!quant.uniform>
- // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor<3x5x2xf32>) -> tensor<3x5x2x!quant.uniform:f32, 6.000000e-03:13>>
+ // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor<3x5x2xf32>) -> tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>
// CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]])
- // CHECK-SAME: (tensor<3x4x5x!quant.uniform>, tensor<3x5x2x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<3x4x2x!quant.uniform>
+ // CHECK-SAME: (tensor<3x4x5x!quant.uniform>, tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>) -> tensor<3x4x2x!quant.uniform>
// CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[CALL]], %[[ARG2]])
// CHECK-SAME: (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi64>) -> tensor<2x3x2x2x!quant.uniform>
// CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[GATHER]]) : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32>
// CHECK: return %[[DQ]]
%0 = "quantfork.qcast"(%arg0) {volatile} : (tensor<3x4x5xf32>) -> tensor<3x4x5x!quant.uniform>
%1 = "quantfork.dcast"(%0) : (tensor<3x4x5x!quant.uniform>) -> tensor<3x4x5xf32>
- %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<3x5x2xf32>) -> tensor<3x5x2x!quant.uniform:f32, 6.000000e-03:13>>
- %3 = "quantfork.dcast"(%2) : (tensor<3x5x2x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<3x5x2xf32>
+ %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<3x5x2xf32>) -> tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>
+ %3 = "quantfork.dcast"(%2) : (tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>) -> tensor<3x5x2xf32>
%4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x5xf32>, tensor<3x5x2xf32>) -> tensor<3x4x2xf32>
%5 = "quantfork.qcast"(%4) {volatile} : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform>
%6 = "quantfork.dcast"(%5) : (tensor<3x4x2x!quant.uniform>) -> tensor<3x4x2xf32>
@@ -321,10 +321,10 @@ module attributes {tf_saved_model.semantics} {
// CHECK: quantized_dot_general_fn_1
// CHECK-SAME: %[[ARG2:.*]]: tensor<3x4x5x!quant.uniform>
- // CHECK-SAME: %[[ARG3:.*]]: tensor<3x5x2x!quant.uniform:f32, 6.000000e-03:13>>
+ // CHECK-SAME: %[[ARG3:.*]]: tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>
func.func private @composite_dot_general_fn_1(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x2xf32>) -> tensor<3x4x2xf32> attributes {_from_xla_call_module} {
// CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]]
- // CHECK-SAME: (tensor<3x4x5x!quant.uniform>, tensor<3x5x2x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<3x4x2x!quant.uniform>
+ // CHECK-SAME: (tensor<3x4x5x!quant.uniform>, tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>) -> tensor<3x4x2x!quant.uniform>
// CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<3x4x2x!quant.uniform>) -> tensor<3x4x2x!quant.uniform>
// CHECK: return %[[Q3]]
%0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x4x5xf32>, tensor<3x5x2xf32>) -> tensor<3x4x2xf32>
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir
index 81e8b4bde5e13e..e152a90ce72c3a 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir
@@ -1,4 +1,4 @@
-// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize=enable-weight-only=true | FileCheck %s
+// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize | FileCheck %s
// Test that hybrid quantized dot_general is produced when q/dq pair only exists
// for weight.
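For orientation, the pattern being matched here is a q/dq (quantize/dequantize) pair wrapped around the weight constant only, while the activation argument never touches a quantfork op. A minimal pre-quantization sketch, assuming the 2x3 weight and the 0.3/127 scale these tests use throughout; the i8 storage type is an assumption on our part:

  %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32>
  // q/dq on the weight only; %arg0 stays float end to end.
  %q = "quantfork.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform<i8:f32, 0.0023622048182750312>>
  %dq = "quantfork.dcast"(%q) : (tensor<2x3x!quant.uniform<i8:f32, 0.0023622048182750312>>) -> tensor<2x3xf32>
  %dot = stablehlo.dot_general %arg0, %dq, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>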
@@ -6,8 +6,8 @@ module attributes {tf_saved_model.semantics} {
func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
%cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32>
- %0 = "quantfork.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform>
- %1 = "quantfork.dcast"(%0) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32>
+ %0 = "quantfork.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform>
+ %1 = "quantfork.dcast"(%0) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32>
%2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
return %2 : tensor<1x3xf32>
}
@@ -21,15 +21,15 @@ module attributes {tf_saved_model.semantics} {
// CHECK-LABEL: quantize_dot_general_fn
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32>
// CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32>
-// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform>
+// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform>
// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[Q]])
-// CHECK-SAME: {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32>
+// CHECK-SAME: {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32>
// CHECK: return %[[CALL]]
// CHECK: quantized_dot_general_fn
-// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform>) -> tensor<1x3xf32>
+// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform>) -> tensor<1x3xf32>
// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]]
-// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32>
+// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32>
// CHECK: return %[[DOT]]
// -----
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir
index 09f002559b7830..a9d805412fdd2a 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir
@@ -17,14 +17,14 @@ module attributes {tf_saved_model.semantics} {
// calls the quantized entry function.
// CHECK: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>
+// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>) -> tensor<1x3x!quant.uniform>
// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32>
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32>
// CHECK-PER-TENSOR: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}>
+// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, {{.*}}>
// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform>
// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32>
@@ -51,7 +51,7 @@ module attributes {tf_saved_model.semantics} {
// -----
-// Tests that `stablehlo.dot_general` with `batching_dim` is not quantized.
+// Tests that `stablehlo.dot_general` with `batching_dim` is quantized.
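Concretely, the op under test carries its batch dimension through: with batching_dims = [0] x [0] and contracting_dims = [2] x [1], a <2x2x2> activation against a <2x2x3> weight is one 2x2 by 2x3 matmul per batch element, yielding <2x2x3>. A float-only sketch of that shape arithmetic (the constant value is hypothetical; shapes are taken from the test below):

  %w = stablehlo.constant dense<3.000000e-01> : tensor<2x2x3xf32>
  %0 = stablehlo.dot_general %arg0, %w, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x2x2xf32>, tensor<2x2x3xf32>) -> tensor<2x2x3xf32>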
module attributes {tf_saved_model.semantics} {
func.func private @quantize_dot_general_batch_per_tensor_quantized_fn(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x3xf32> attributes {tf._original_func_name = "main_0"} {
@@ -62,9 +62,9 @@ module attributes {tf_saved_model.semantics} {
return %2 : tensor<2x2x3xf32>
}
// CHECK: func.func private @quantize_dot_general_batch_per_tensor_quantized_fn(%[[ARG_0:.+]]: tensor<2x2x2xf32>) -> tensor<2x2x3xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x2x3xi8>} : () -> tensor<2x2x3x!quant.uniform>
+// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x2x3xi8>}> : () -> tensor<2x2x3x!quant.uniform:f32, {{.*}}>>
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<2x2x2xf32>) -> tensor<2x2x2x!quant.uniform>
-// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<2x2x2x!quant.uniform>, tensor<2x2x3x!quant.uniform) -> tensor<2x2x3x!quant.uniform>
+// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<2x2x2x!quant.uniform>, tensor<2x2x3x!quant.uniform:f32, {{.*}}>) -> tensor<2x2x3x!quant.uniform>
// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<2x2x3x!quant.uniform) -> tensor<2x2x3xf32>
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<2x2x3xf32>
@@ -88,16 +88,16 @@ module attributes {tf_saved_model.semantics} {
return %2 : tensor<1x3xf32>
}
// CHECK: func.func private @quantize_dot_general_with_bias_same_shape_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>
-// CHECK: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform
+// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>
+// CHECK: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x3xi32>}> : () -> tensor<1x3x!quant.uniform
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform>
// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32>
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32>
// CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_same_shape_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}>
-// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform
+// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, {{.*}}>
+// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x3xi32>}> : () -> tensor<1x3x!quant.uniform
// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform>
// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32>
@@ -137,16 +137,16 @@ module attributes {tf_saved_model.semantics} {
return %2 : tensor
}
// CHECK: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"}
-// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>
-// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3xi32>} : () -> tensor<3x!quant.uniform
+// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>
+// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<3xi32>}> : () -> tensor<3x!quant.uniform
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor>
// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<3x!quant.uniform) -> tensor
// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor
// CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"}
-// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}>
-// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3xi32>} : () -> tensor<3x!quant.uniform
+// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, {{.*}}>
+// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<3xi32>}> : () -> tensor<3x!quant.uniform
// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor>
// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<3x!quant.uniform) -> tensor
// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor
@@ -219,7 +219,7 @@ module attributes {tf_saved_model.semantics} {
// calls the quantized entry function.
// CHECK: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
+// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform>
// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]])
// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform>
@@ -227,7 +227,7 @@ module attributes {tf_saved_model.semantics} {
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32>
// CHECK-PER-TENSOR: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
+// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform>
// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]])
// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform>
@@ -286,7 +286,7 @@ func.func @quantize_conv_fn_per_tensor(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3
// calls the quantized entry function.
// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32>
-// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform
+// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform>
// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform>
// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32>
@@ -338,8 +338,8 @@ module attributes {tf_saved_model.semantics} {
return %2 : tensor<1x3x4x2xf32>
}
// CHECK: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
-// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<47978> : tensor<2xi32>} : () -> tensor<2x!quant.uniform>
+// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
+// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<47978> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform>
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform>
// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform>
@@ -347,8 +347,8 @@ module attributes {tf_saved_model.semantics} {
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32>
// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
-// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2xi32>} : () -> tensor<2x!quant.uniform
+// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
+// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform
// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform>
// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform
@@ -407,8 +407,8 @@ module attributes {tf_saved_model.semantics} {
return %2 : tensor<1x3x4x2xf32>
}
// CHECK: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
-// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform
+// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
+// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform>
// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform
@@ -416,8 +416,8 @@ module attributes {tf_saved_model.semantics} {
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32>
// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"}
-// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
-// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform
+// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
+// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform
// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform>
// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform
@@ -475,8 +475,8 @@ module attributes {tf_saved_model.semantics} {
return %2 : tensor
}
// CHECK: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"}
-// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
-// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform
+// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
+// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor>
// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor
@@ -484,8 +484,8 @@ module attributes {tf_saved_model.semantics} {
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor
// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"}
-// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
-// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform
+// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
+// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform
// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor>
// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor
@@ -569,8 +569,8 @@ module attributes {tf_saved_model.semantics} {
return %2 : tensor
}
// CHECK: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"}
-// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
-// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform
+// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
+// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor>
// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor>
@@ -578,8 +578,8 @@ module attributes {tf_saved_model.semantics} {
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor
// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"}
-// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
-// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform
+// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
+// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform
// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor>
// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor>
@@ -666,8 +666,8 @@ module attributes {tf_saved_model.semantics} {
return %2 : tensor
}
// CHECK: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"}
-// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
-// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform
+// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>
+// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor>
// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor>
@@ -675,8 +675,8 @@ module attributes {tf_saved_model.semantics} {
// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor
// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"}
-// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
-// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform
+// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>
+// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform
// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor>
// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]])
// CHECK-PER-TENSOR: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor>
@@ -810,8 +810,8 @@ module attributes {tf_saved_model.semantics} {
%5 = "quantfork.stats"(%4) {layerStats = dense<[5.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
return %5 : tensor<1x3xf32>
}
-// CHECK: %[[CONST:.+]] = stablehlo.constant() {value = dense<127> : tensor<1x2xi8>} : () -> tensor<1x2x!quant.uniform>
-// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>>
+// CHECK: %[[CONST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<1x2xi8>}> : () -> tensor<1x2x!quant.uniform>
+// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>>
// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
// CHECK: %[[CALL:.+]] = call @quantized_add_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
// CHECK: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
@@ -837,3 +837,53 @@ module attributes {tf_saved_model.semantics} {
// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform>
// CHECK: return %[[UNIFORM_QUANTIZE]] : tensor<1x3x!quant.uniform>
}
+
+// -----
+
+// Tests that `stablehlo.add` is not quantized and an error is emitted when the
+// entry function is not singular, i.e. it includes two ops rather than one.
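For contrast, a composite body that would pass the singular-op check contains exactly one quantizable op between its arguments and its return; a hypothetical sketch, not taken from this patch:

  func.func private @composite_add_fn_singular(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} {
    // A single op in the body: this function is singular and would be quantized.
    %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32>
    return %0 : tensor<1x2xf32>
  }

The test below instead gives the body a second stablehlo.add, which is what trips the not-singular path and leaves the float op in place.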
+
+module attributes {tf_saved_model.semantics} {
+ func.func private @not_quantize_fn_when_not_singular(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {tf._original_func_name = "main_0"} {
+ %cst = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+ %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
+ %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _original_entry_function = "composite_add_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
+ // expected-error@+1 {{'stablehlo.uniform_dequantize' op operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2xf32>'}}
+ %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
+ return %2 : tensor<1x2xf32>
+ }
+
+ func.func private @composite_add_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} {
+ %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32>
+ %1 = stablehlo.add %0, %arg1 : tensor<1x2xf32>
+ return %1 : tensor<1x2xf32>
+ }
+}
+
+// -----
+
+// Tests that `stablehlo.gather` without `static_range_ptq` is not quantized.
+
+module attributes {tf_saved_model.semantics} {
+ func.func private @not_quantize_singular_op_without_static_range_ptq(%arg: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} {
+ %cst = "tf.Const"() {value = dense<1> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32>
+ %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32>
+ %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, _original_entry_function = "composite_gather_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32>
+ // expected-error@+1 {{'stablehlo.uniform_dequantize' op operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<2x3x2x2xf32>'}}
+ %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2xf32>
+ return %2 : tensor<2x3x2x2xf32>
+ }
+
+ func.func private @composite_gather_fn(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> attributes {_from_xla_call_module} {
+ %0 = "stablehlo.gather"(%arg0, %arg1) {
+ dimension_numbers = #stablehlo.gather<
+ offset_dims = [2, 3],
+ collapsed_slice_dims = [0],
+ start_index_map = [1, 0],
+ index_vector_dim = 2>,
+ slice_sizes = array,
+ indices_are_sorted = false
+ } : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32>
+ return %0 : tensor<2x3x2x2xf32>
+ }
+}
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir
index b96cb15039d763..148e1330cfca34 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir
@@ -1,11 +1,11 @@
// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \
-// RUN: -stablehlo-quantize-composite-functions=enable-weight-only=true | FileCheck --check-prefix=CHECK %s
+// RUN: -stablehlo-quantize-composite-functions | FileCheck --check-prefix=CHECK %s
-// Test that weight-only quantized dot_general op is produced when
-// weight_only_ptq is provided.
+// Test that per-tensor weight-only quantized dot_general op is produced when
+// empty `weight_only_ptq` is provided.
module attributes {tf_saved_model.semantics} {
- func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
+ func.func private @quantize_dot_general_per_tensor(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
%0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32>
%1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
return %1 : tensor<1x3xf32>
@@ -17,25 +17,25 @@ module attributes {tf_saved_model.semantics} {
}
}
-// CHECK-LABEL: quantize_dot_general_fn
+// CHECK-LABEL: quantize_dot_general_per_tensor
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32>
-// CHECK: %[[CST:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform>
-// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32>
+// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>
+// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3xf32>
// CHECK: return %[[CALL]]
// CHECK: quantized_dot_general_fn
-// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform>) -> tensor<1x3xf32>
+// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3xf32>
// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]]
-// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32>
+// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3xf32>
// CHECK: return %[[DOT]]
// -----
-// Test that hybrid quantized convolution op is produced when weight_only_ptq is
-// provided.
+// Test that per-tensor weight-only quantized convolution op is produced when
+// empty `weight_only_ptq` is provided.
module attributes {tf_saved_model.semantics} {
- func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} {
+ func.func private @quantize_conv_per_tensor(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} {
%0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32>
%1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
return %1 : tensor<1x3x4x2xf32>
@@ -47,14 +47,76 @@ module attributes {tf_saved_model.semantics} {
}
}
-// CHECK-LABEL: quantize_conv_fn
+// CHECK-LABEL: quantize_conv_per_tensor
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32>
-// CHECK: %[[CST:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform>
-// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32>
+// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>
+// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3x4x2xf32>
// CHECK: return %[[CALL]]
// CHECK: quantized_conv_fn
-// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32>
+// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3x4x2xf32>
// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]])
-// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32>
+// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3x4x2xf32>
+// CHECK: return %[[CONV]]
+
+// -----
+
+// Test that per-channel weight-only quantized dot_general op is produced when
+// `weight_only_ptq` with `dimension_specs` is provided.
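The only knob that separates these per-channel tests from the per-tensor ones above is the _quantization_method textproto on the tf.XlaCallModule op: an empty dimension_specs {} asks for per-channel quantization with the quantized dimension inferred, while an explicit dimension field pins it. All three forms appear verbatim in this patch:

  // Per-tensor:            _quantization_method = "weight_only_ptq { }"
  // Per-channel, inferred: _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}"
  // Per-channel, explicit: _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"

Here key: 1 selects operand 1 of the composite call, i.e. the weight.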
+
+module attributes {tf_saved_model.semantics} {
+ func.func private @quantize_dot_general_per_channel(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
+ %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32>
+ %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+ return %1 : tensor<1x3xf32>
+ }
+
+ func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
+ %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+ return %0 : tensor<1x3xf32>
+ }
+}
+
+// CHECK-LABEL: quantize_dot_general_per_channel
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32>
+// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>
+// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}"}
+// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3xf32>
+// CHECK: return %[[CALL]]
+
+// CHECK: quantized_dot_general_fn
+// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3xf32>
+// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]]
+// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3xf32>
+// CHECK: return %[[DOT]]
+
+// -----
+
+// Test that per-channel weight-only quantized convolution op is produced when
+// `weight_only_ptq` with `dimension_specs` is provided.
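For the convolution flavor, the inferred channel axis lands on the output-feature dimension of the [0, 1, i, o] kernel, i.e. dimension 3, so the per-channel weight type carries one scale per output feature (o = 2 here), versus a single scale in the per-tensor test above. Side by side, with the i8 storage type assumed:

  // Per-tensor (one scale):
  //   tensor<2x3x3x2x!quant.uniform<i8:f32, 0.0023622048182750312>>
  // Per-channel on dim 3 (two scales, one per output feature):
  //   tensor<2x3x3x2x!quant.uniform<i8:f32:3, {0.0023622048182750312,0.0023622048182750312}>>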
+
+module attributes {tf_saved_model.semantics} {
+ func.func private @quantize_conv_per_channel(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} {
+ %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32>
+ %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
+ return %1 : tensor<1x3x4x2xf32>
+ }
+
+ func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} {
+ %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
+ return %0 : tensor<1x3x4x2xf32>
+ }
+}
+
+// CHECK-LABEL: quantize_conv_per_channel
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32>
+// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>
+// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}"}
+// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3x4x2xf32>
+// CHECK: return %[[CALL]]
+
+// CHECK: quantized_conv_fn
+// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3x4x2xf32>
+// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]])
+// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3x4x2xf32>
// CHECK: return %[[CONV]]
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir
index 02e1c5e9923915..ab55a8bc2989bb 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir
@@ -22,23 +22,23 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
func.func @main(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x64xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32>
%1 = stablehlo.constant dense<1.000000e+03> : tensor<1x3xf32>
- %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>)
+ %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>)
%3 = "tf.XlaCallModule"(%2#0, %0, %1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32>
- %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>)
+ %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>)
%5 = stablehlo.constant dense<1.000000e+03> : tensor<3x64xf32>
%6 = stablehlo.constant dense<1.000000e+03> : tensor<1x64xf32>
- %7:4 = "tf.CustomAggregator"(%4#0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>)
+ %7:4 = "tf.CustomAggregator"(%4#0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>)
%8 = "tf.XlaCallModule"(%7#0, %5, %6) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x3xf32>, tensor<3x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32>
- %9:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x64xf32>) -> (tensor<1x64xf32>, tensor, tensor, tensor<*xi64>)
+ %9:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x64xf32>) -> (tensor<1x64xf32>, tensor, tensor, tensor<*xi64>)
return %9#0 : tensor<1x64xf32>
}
// CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_1
- // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}
+ // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}>
// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable"}
// CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_0]])
// CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0
- // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}
+ // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}>
// CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable"}
// CHECK: %[[CUSTOM_AGGREGATOR_3:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_1:.*]])
// CHECK: return %[[CUSTOM_AGGREGATOR_3]] : tensor<1x64xf32>
@@ -111,14 +111,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
// CHECK: @serving_default
func.func @serving_default(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32>
- %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>)
+ %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>)
%2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
- %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>)
+ %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>)
return %3#0 : tensor<1x3xf32>
}
// CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}}
- // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}
+ // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}>
// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
// CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]])
// CHECK: return %[[CUSTOM_AGGREGATOR_1]]
@@ -143,14 +143,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
// CHECK-LABEL: @random_name
func.func @random_name(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32>
- %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>)
+ %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>)
%2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
- %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>)
+ %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>)
return %3#0 : tensor<1x3xf32>
}
// CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32>
- // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}
+ // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}>
// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
// CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]])
// CHECK: return %[[CUSTOM_AGGREGATOR_1]]
@@ -185,9 +185,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
// CHECK: @serving_default
func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output1"]}, tensor<1024x3xf32> {tf_saved_model.index_path = ["output2"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32>
- %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>)
+ %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>)
%2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32>
- %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>)
+ %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>)
%4 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32>
%5 = stablehlo.add %3#0, %4 : tensor<1024x3xf32>
%6 = stablehlo.multiply %3#0, %0 : tensor<1024x3xf32>
@@ -195,7 +195,7 @@ module attributes {tf.versions = {bad_consumers =
[], min_consumer = 12 : i32, p } // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_2:.*]]:2 = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 @@ -235,16 +235,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %0 = stablehlo.constant dense<1.000000e+03> : tensor<3x11xf32> // %1 is large enough that it won't be duplicated. %1 = stablehlo.constant dense<1.000000e+01> : tensor<3x11xf32> - %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor, tensor, tensor<*xi64>) + %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor, tensor, tensor<*xi64>) %3 = "tf.XlaCallModule"(%2#0, %0) {Sout = [#tf_type.shape<3x11>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x3xf32>, tensor<3x11xf32>) -> tensor<3x11xf32> - %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x11xf32>) -> (tensor<3x11xf32>, tensor, tensor, tensor<*xi64>) + %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<3x11xf32>) -> (tensor<3x11xf32>, tensor, tensor, tensor<*xi64>) %5 = stablehlo.add %4#0, %1 : tensor<3x11xf32> %6 = stablehlo.multiply %5, %1 : tensor<3x11xf32> return %6 : tensor<3x11xf32> } // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<3x11>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : 
i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<3x11>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]]) <{Sout = [#tf_type.shape<3x11>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_0 @@ -293,14 +293,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %4 = stablehlo.compare EQ, %3, %2, NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> stablehlo.custom_call @shape_assertion(%4) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () %5 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> - %6:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %6:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) %7 = "tf.XlaCallModule"(%6#0, %5) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> - %8:4 = "tf.CustomAggregator"(%7) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %8:4 = "tf.CustomAggregator"(%7) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) %9 = stablehlo.add %8#0, %0 : tensor<1024x3xf32> return %9 : tensor<1024x3xf32> } // CHECK: %[[SUBGRAPH_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_0]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = 
"tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 @@ -339,16 +339,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %2 = stablehlo.remainder %0, %1 : tensor<1024x3xf32> %3 = "tf.Identity"(%2) {device = ""} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> %4 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> - %5:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %5:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) %6 = "tf.XlaCallModule"(%5#0, %4) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> - %7:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %7:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) %8 = stablehlo.add %7#0, %0 : tensor<1024x3xf32> return %8 : tensor<1024x3xf32> } // CHECK: %[[SUBGRAPH_0:.*]]:2 = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_2 // CHECK: %[[IDENTIFY:.*]] = "tf.Identity"(%[[SUBGRAPH_0]]#1) {device = ""} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_0]]#0) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 @@ -394,14 +394,14 @@ module attributes 
{tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %7 = stablehlo.compare EQ, %6, %2, NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> stablehlo.custom_call @shape_assertion(%7) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () %8 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> - %9:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %9:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) %10 = "tf.XlaCallModule"(%9#0, %8) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> - %11:4 = "tf.CustomAggregator"(%10) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %11:4 = "tf.CustomAggregator"(%10) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) %12 = stablehlo.add %11#0, %0 : tensor<1024x3xf32> return %12 : tensor<1024x3xf32> } // CHECK: %[[SUBGRAPH_0:.*]]:2 = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_0]]#1) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_0]]#0) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir index 831131a4c64555..5e443526c650f1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir @@ -100,8 +100,9 @@ 
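Two things change in every CHECK line of the test file above: the op's attributes move from the trailing discardable-attribute dictionary {...} into the inherent-attribute dictionary <{...}>, and initial_num_bins becomes num_bins. The new ordering in the expectations is not hand-picked: MLIR stores and prints attribute dictionaries sorted by name, so the rename moves the entry after min_percentile. A minimal C++ sketch of that sorting behavior (illustrative only; MakeAggregatorAttrs is a hypothetical helper, not code from this change):

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

mlir::DictionaryAttr MakeAggregatorAttrs(mlir::MLIRContext* ctx) {
  mlir::Builder b(ctx);
  mlir::NamedAttrList attrs;
  // Insertion order is deliberately scrambled here.
  attrs.set("num_bins", b.getI32IntegerAttr(0));
  attrs.set("id", b.getStringAttr("0"));
  attrs.set("calibration_method", b.getI32IntegerAttr(1));
  // getDictionary() returns entries sorted by name regardless of insertion
  // order: calibration_method, id, num_bins -- the order FileCheck sees.
  return attrs.getDictionary(ctx);
}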
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir
index 831131a4c64555..5e443526c650f1 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir
@@ -100,8 +100,9 @@ func.func @nchw_conv_with_bias_add_max_pool(%arg0: tensor<1x2x5x5xf32>) -> tenso
 // CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[TRANSPOSE_0]], %[[WEIGHT_CONST]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = {{\[\[}}1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2xf32>, tensor<3x3x2x4xf32>) -> tensor<1x5x5x4xf32>
 // CHECK: %[[ADD:.+]] = stablehlo.add %[[CONV]], %[[BIAS_CONST]] : tensor<1x5x5x4xf32>
 // CHECK: %[[REDUCE_WINDOW_MAX:.+]] = "stablehlo.reduce_window"(%[[ADD]], %[[INIT_VALUE_CONST:.+]])
+// CHECK: <{window_dimensions = array<i64: 1, 3, 3, 1>, window_strides = array<i64: 1, 2, 2, 1>}>
 // CHECK: stablehlo.maximum
-// CHECK: {window_dimensions = array<i64: 1, 3, 3, 1>, window_strides = array<i64: 1, 2, 2, 1>} : (tensor<1x5x5x4xf32>, tensor<f32>) -> tensor<1x2x2x4xf32>
+// CHECK: (tensor<1x5x5x4xf32>, tensor<f32>) -> tensor<1x2x2x4xf32>
 // CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[REDUCE_WINDOW_MAX]], dims = [0, 3, 1, 2] : (tensor<1x2x2x4xf32>) -> tensor<1x4x2x2xf32>
 // CHECK: return %[[TRANSPOSE_1]]
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.cc
index efda4282b2cbec..640f0ebc5c5061 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.cc
@@ -29,7 +29,7 @@ bool IsLargeFloatType(Type type) {
 }
 
 Type ToBfloat16Type(Type type) {
-  if (auto shaped = type.dyn_cast<ShapedType>()) {
+  if (auto shaped = mlir::dyn_cast<ShapedType>(type)) {
     const Type elem = shaped.getElementType();
     if (IsLargeFloatType(elem)) {
       return shaped.clone(BFloat16Type::get(type.getContext()));
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils_test.cc
index fefdfbbb543123..1558a5478e604f 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils_test.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils.h"
 
 #include <gmock/gmock.h>
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 
 namespace mlir::quant::stablehlo {
 namespace {
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.cc
index 2f801565b93a1f..555e8af25b374f 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.cc
@@ -37,8 +37,8 @@ limitations under the License.
 namespace mlir::quant::tensorflow {
 
 bool IsTFQintType(const Type type) {
-  return type.isa<TF::Qint8Type, TF::Qint16Type, TF::Qint32Type, TF::Quint8Type, TF::Quint16Type>();
+  return mlir::isa<TF::Qint8Type, TF::Qint16Type, TF::Qint32Type, TF::Quint8Type, TF::Quint16Type>(type);
 }
 
 Type GetIntTypeFromTFQint(const Type type) {
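The bfloat16_type.cc and tf_type_utils.cc hunks above are part of a broader migration: MLIR has deprecated the member-style casts (type.isa<T>(), type.dyn_cast<T>()) in favor of the free functions mlir::isa<T>(type) and mlir::dyn_cast<T>(type). A self-contained sketch of the pattern (assumes only standard MLIR headers; ElementTypeOrSelf is a hypothetical helper, not code from this change):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

// Member style (deprecated):     type.dyn_cast<mlir::ShapedType>()
// Free-function style (current): mlir::dyn_cast<mlir::ShapedType>(type)
mlir::Type ElementTypeOrSelf(mlir::Type type) {
  if (auto shaped = mlir::dyn_cast<mlir::ShapedType>(type))
    return shaped.getElementType();
  return type;  // Non-shaped types are returned unchanged.
}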
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -188,31 +189,31 @@ TEST(GetIntTypeFromTFQintTest, ChecksIntTypesFromTFQint) { auto type = GetIntTypeFromTFQint(TF::Qint8Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 8); - EXPECT_FALSE(type.dyn_cast().isSigned()); - EXPECT_FALSE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 8); + EXPECT_FALSE(mlir::dyn_cast(type).isSigned()); + EXPECT_FALSE(mlir::dyn_cast(type).isUnsigned()); type = GetIntTypeFromTFQint(TF::Qint16Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 16); - EXPECT_FALSE(type.dyn_cast().isSigned()); - EXPECT_FALSE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 16); + EXPECT_FALSE(mlir::dyn_cast(type).isSigned()); + EXPECT_FALSE(mlir::dyn_cast(type).isUnsigned()); type = GetIntTypeFromTFQint(TF::Qint32Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 32); - EXPECT_FALSE(type.dyn_cast().isSigned()); - EXPECT_FALSE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 32); + EXPECT_FALSE(mlir::dyn_cast(type).isSigned()); + EXPECT_FALSE(mlir::dyn_cast(type).isUnsigned()); type = GetIntTypeFromTFQint(TF::Quint8Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 8); - EXPECT_TRUE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 8); + EXPECT_TRUE(mlir::dyn_cast(type).isUnsigned()); type = GetIntTypeFromTFQint(TF::Quint16Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 16); - EXPECT_TRUE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 16); + EXPECT_TRUE(mlir::dyn_cast(type).isUnsigned()); // Non qint types are returned as is. 
EXPECT_EQ(GetIntTypeFromTFQint(IntegerType::get(type.getContext(), 32)), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 94dc1b1569620f..1762a67d7d3acf 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -348,7 +348,6 @@ cc_library( "passes/insert_quantized_functions.cc", "passes/insert_restore_op.cc", "passes/insert_save_op.cc", - "passes/issue_ids_of_custom_aggregation_ops.cc", "passes/lift_hashtable_ops_as_args.cc", "passes/lift_quantizable_spots_as_functions.cc", "passes/lift_quantizable_spots_as_functions.inc", @@ -504,6 +503,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD index 9ae8d6401afcd6..6e6ee260f48a2c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD @@ -12,13 +12,8 @@ load( "get_compatible_with_portable", "tf_kernel_library", "tf_py_strict_test", - "tf_python_pybind_extension", ) load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") -load( - "//tensorflow/core/platform:build_config_root.bzl", - "if_static", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -30,49 +25,6 @@ package( licenses = ["notice"], ) -# Directly linked to `custom_aggregator_op`. In general, one should avoid directly depending on -# this target to avoid the ODR violation. Depend on `calibrator_singleton` instead. 
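The comment being deleted above is the key to the BUILD deletions that follow: calibrator_singleton_impl was linked directly into the kernel library only, while all other code was expected to reach the same state through the header-only calibrator_singleton target (via if_static), because linking a second copy of a singleton's implementation into another shared object gives that object its own instance. A hypothetical C++ illustration of the failure mode the comment warns about (not TensorFlow code):

// If two .so files each link their own copy of this class, Get() returns a
// different object in each library, so data reported in one library is
// invisible to the other -- hence the warning against depending on the
// impl target directly.
class StatsRegistry {
 public:
  static StatsRegistry& Get() {
    static StatsRegistry instance;  // One instance *per linked copy*.
    return instance;
  }
  void Report(float min, float max) { min_ = min; max_ = max; }

 private:
  StatsRegistry() = default;
  float min_ = 0.0f;
  float max_ = 0.0f;
};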
-cc_library( - name = "calibrator_singleton_impl", - srcs = ["calibrator_singleton.cc"], - hdrs = ["calibrator_singleton.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - ":calibration_statistics_collector_average_min_max", - ":calibration_statistics_collector_base", - ":calibration_statistics_collector_histogram", - ":calibration_statistics_collector_min_max", - ":calibration_statistics_proto_cc", - "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", - "//tensorflow/core:framework", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "calibrator_singleton", - hdrs = ["calibrator_singleton.h"], - compatible_with = get_compatible_with_portable(), - deps = if_static([":calibrator_singleton_impl"]) + [ - ":calibration_statistics_collector_base", - ":calibration_statistics_proto_cc", - "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", - "//tensorflow/core:framework", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - ], -) - cc_library( name = "calibration_statistics_collector_base", hdrs = ["calibration_statistics_collector_base.h"], @@ -181,20 +133,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "calibrator_singleton_test", - size = "small", - srcs = ["calibrator_singleton_test.cc"], - deps = [ - ":calibration_statistics_proto_cc", - ":calibrator_singleton_impl", - "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_googletest//:gtest_main", - ], -) - tf_kernel_library( name = "custom_aggregator_op", srcs = ["custom_aggregator_op.cc"], @@ -204,7 +142,6 @@ tf_kernel_library( "//tensorflow/compiler/mlir/quantization/tensorflow/python:__pkg__", ], deps = [ - ":calibrator_singleton_impl", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:calibration_parameters", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -238,7 +175,6 @@ tf_py_strict_test( deps = [ ":calibration_statistics_proto_py", ":gen_custom_aggregator_op_wrapper", - ":pywrap_calibration", "//tensorflow:tensorflow_py", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_py", "//tensorflow/python:pywrap_tensorflow", @@ -249,20 +185,6 @@ tf_py_strict_test( ], ) -tf_python_pybind_extension( - name = "pywrap_calibration", - srcs = ["pywrap_calibration.cc"], - pytype_srcs = ["pywrap_calibration.pyi"], - deps = [ - ":calibration_statistics_proto_cc", - ":calibrator_singleton", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@pybind11", - "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ], -) - tf_kernel_library( name = "calibration_statistics_saver_op", srcs = ["calibration_statistics_saver_op.cc"], @@ -276,6 +198,7 @@ tf_kernel_library( ":calibration_statistics_collector_base", 
":calibration_statistics_collector_histogram", ":calibration_statistics_collector_min_max", + ":calibration_statistics_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", @@ -298,6 +221,7 @@ tf_cc_test( "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc index 8061ad3fe2d444..4b30fab0cc39fc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h" @@ -36,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/logging.h" #include "tsl/platform/file_system.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op_test.cc index 8335722cdea929..15cb07f4b93270 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc deleted file mode 100644 index 74575b761737a3..00000000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc +++ /dev/null @@ -1,126 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" - -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h" -#include "tensorflow/core/framework/tensor.h" - -namespace tensorflow { -namespace calibrator { - -using ::stablehlo::quantization::CalibrationOptions; - -ABSL_CONST_INIT absl::Mutex CalibratorSingleton::lock_(absl::kConstInit); - -CalibratorSingleton& CalibratorSingleton::GetInstance() { - static CalibratorSingleton* calibrator = new CalibratorSingleton(); - return *calibrator; -} - -void CalibratorSingleton::ClearCollectedInformation() { - absl::MutexLock lock(&lock_); - - CalibratorSingleton& instance = GetInstance(); - instance.id_to_collector_.clear(); -} - -void CalibratorSingleton::ClearData(absl::string_view id) { - absl::MutexLock lock(&lock_); - - CalibratorSingleton& instance = GetInstance(); - - const std::string id_str{id}; - instance.id_to_collector_[id_str].reset(nullptr); -} - -void CalibratorSingleton::Report(absl::string_view id, const Tensor& min_tensor, - const Tensor& max_tensor, - const Tensor& histogram_tensor, - const CalibrationOptions& calib_opts) { - const float min_value = min_tensor.scalar()(); - const float max_value = max_tensor.scalar()(); - auto histogram_flat = histogram_tensor.flat(); - absl::Span histogram_data = - absl::MakeSpan(histogram_flat.data(), histogram_flat.size()); - Report(id, min_value, max_value, histogram_data, calib_opts); -} - -void CalibratorSingleton::Report(absl::string_view id, float min, float max, - absl::Span histogram, - const CalibrationOptions& calib_opts) { - absl::MutexLock lock(&lock_); - - CalibratorSingleton& instance = GetInstance(); - const std::string id_str{id}; - AssignIfNotExists(id_str, calib_opts); - instance.id_to_collector_[id_str]->Collect(min, max, histogram); -} - -std::optional CalibratorSingleton::GetStatistics( - absl::string_view id) { - absl::MutexLock lock(&lock_); - - CalibratorSingleton& instance = GetInstance(); - - const std::string id_str{id}; - - if (!instance.id_to_collector_[id_str]) { - return std::nullopt; - } - - return instance.id_to_collector_[id_str]->GetStatistics(); -} - -void CalibratorSingleton::AssignIfNotExists( - std::string id_str, const CalibrationOptions& calib_opts) { - CalibratorSingleton& instance = GetInstance(); - if 
(instance.id_to_collector_[id_str]) return; - - switch (calib_opts.calibration_method()) { - case CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX: - instance.id_to_collector_[id_str] = - std::make_unique(); - break; - case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE: - case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE: - case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC: - case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY: - instance.id_to_collector_[id_str] = - std::make_unique(); - break; - case CalibrationOptions::CALIBRATION_METHOD_MIN_MAX: - default: - instance.id_to_collector_[id_str] = - std::make_unique(); - } -} - -} // namespace calibrator -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h deleted file mode 100644 index 8a6aee81ee9cbd..00000000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATOR_SINGLETON_H_ -#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATOR_SINGLETON_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/core/framework/tensor.h" - -namespace tensorflow { -namespace calibrator { - -using stablehlo::quantization::CalibrationOptions; - -// TODO: b/315084876 - Move to stablehlo quantizer directory. -class CalibratorSingleton { - public: - // Clears the collected information. - static void ClearCollectedInformation(); - - // Clears the collected data of the given node id. - static void ClearData(absl::string_view id); - - // Reports data to the singleton. Only calculates the required statistics - // based on CalibrationOptions. - static void Report(absl::string_view id, const Tensor& min_tensor, - const Tensor& max_tensor, const Tensor& histogram_tensor, - const CalibrationOptions& calib_opts); - - // Same as above but accepts primitive input types. 
- static void Report(absl::string_view id, float min, float max, - absl::Span histogram, - const CalibrationOptions& calib_opts); - - // Returns the calibration statistics of the given id. - static std::optional GetStatistics( - absl::string_view id); - - private: - static CalibratorSingleton& GetInstance(); - static absl::Mutex lock_; - static void AssignIfNotExists(std::string id_str, - const CalibrationOptions& calib_opts); - - absl::flat_hash_map> - id_to_collector_; - - CalibratorSingleton() = default; - ~CalibratorSingleton() = default; -}; - -} // namespace calibrator -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATOR_SINGLETON_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc deleted file mode 100644 index ca338b58c5909d..00000000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc +++ /dev/null @@ -1,203 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" - -#include -#include -#include - -#include -#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace calibrator { -namespace { - -using ::stablehlo::quantization::CalibrationOptions; - -TEST(CalibratorSingletonTest, SimpleMinMax) { - CalibrationOptions calib_opts; - calib_opts.set_calibration_method( - CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); - - CalibratorSingleton::Report(/*id=*/"1", /*min=*/1.0f, /*max=*/5.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - std::optional statistics = - CalibratorSingleton::GetStatistics(/*id=*/"1"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 5.0f); - - CalibratorSingleton::Report(/*id=*/"1", /*min=*/1.0f, /*max=*/10.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"1"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 10.0f); - - CalibratorSingleton::Report(/*id=*/"1", /*min=*/-5.0f, /*max=*/5.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"1"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().min_max_statistics().global_min(), -5.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 
10.0f); -} - -TEST(CalibratorSingletonTest, DifferentSessions) { - CalibrationOptions calib_opts; - calib_opts.set_calibration_method( - CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); - - CalibratorSingleton::Report(/*id=*/"2", /*min=*/1.0f, /*max=*/5.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - std::optional statistics = - CalibratorSingleton::GetStatistics(/*id=*/"2"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 5.0f); - - CalibratorSingleton::Report(/*id=*/"2", /*min=*/1.0f, /*max=*/10.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"2"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 10.0f); - - CalibratorSingleton::Report(/*id=*/"3", /*min=*/-5.0f, /*max=*/5.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"3"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().min_max_statistics().global_min(), -5.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 5.0f); -} - -TEST(CalibratorSingletonTest, ClearAndGetEmptyResult) { - std::vector> report_vec; - CalibrationOptions calib_opts; - calib_opts.set_calibration_method( - CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); - - report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 10.0f}); - - CalibratorSingleton::Report(/*id=*/"4", /*min=*/1.0f, /*max=*/5.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - std::optional statistics = - CalibratorSingleton::GetStatistics(/*id=*/"4"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 5.0f); - - CalibratorSingleton::ClearData(/*id=*/"4"); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"4"); - - EXPECT_FALSE(statistics.has_value()); -} - -TEST(CalibratorSingletonTest, ClearDataAndGetResults) { - CalibrationOptions calib_opts; - calib_opts.set_calibration_method( - CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); - - CalibratorSingleton::Report(/*id=*/"5", /*min=*/1.0f, /*max=*/5.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - std::optional statistics = - CalibratorSingleton::GetStatistics(/*id=*/"5"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 5.0f); - - CalibratorSingleton::Report(/*id=*/"6", /*min=*/1.0f, /*max=*/10.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"6"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 10.0f); - - CalibratorSingleton::ClearData(/*id=*/"5"); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"5"); - - EXPECT_FALSE(statistics.has_value()); - - CalibratorSingleton::Report(/*id=*/"6", /*min=*/1.0f, /*max=*/10.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"6"); - - EXPECT_TRUE(statistics.has_value()); - 
EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); - EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 10.0f); -} - -TEST(CalibratorSingletonTest, SimpleAverageMinMax) { - CalibrationOptions calib_opts; - calib_opts.set_calibration_method( - CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX); - - CalibratorSingleton::Report(/*id=*/"7", /*min=*/-10.0f, /*max=*/30.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - std::optional statistics = - CalibratorSingleton::GetStatistics(/*id=*/"7"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().average_min_max_statistics().min_sum(), -10.0f); - EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 30.0f); - EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 1); - - CalibratorSingleton::Report(/*id=*/"7", /*min=*/-20.0f, /*max=*/60.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"7"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().average_min_max_statistics().min_sum(), -30.0f); - EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 90.0f); - EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 2); - - CalibratorSingleton::Report(/*id=*/"7", /*min=*/-30.0f, /*max=*/90.0f, - /*histogram=*/{}, - /*calib_opts=*/calib_opts); - statistics = CalibratorSingleton::GetStatistics(/*id=*/"7"); - - EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().average_min_max_statistics().min_sum(), -60.0f); - EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 180.0f); - EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 3); -} - -} // namespace -} // namespace calibrator -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc index 66d932a44f6179..ea37ab7b2be9bf 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "absl/status/status.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" @@ -37,7 +36,6 @@ using ::stablehlo::quantization::CalculateBinIndexSafe; using ::stablehlo::quantization::CalculateBinWidth; using ::stablehlo::quantization::CalculateLowerBound; using ::stablehlo::quantization::CalibrationOptions; -using ::stablehlo::quantization::GetNumBins; using CPUDevice = ::Eigen::ThreadPoolDevice; using CalibrationMethod = ::stablehlo::quantization::CalibrationOptions_CalibrationMethod; @@ -52,7 +50,7 @@ REGISTER_OP("CustomAggregator") .Output("histogram: int64") .Attr("id: string") .Attr("calibration_method: int = 0") - .Attr("initial_num_bins: int = 0") + .Attr("num_bins: int = 0") .Attr("min_percentile: float = 0.0") .Attr("max_percentile: float = 0.0") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { @@ -60,12 +58,9 @@ REGISTER_OP("CustomAggregator") c->set_output(1, c->Scalar()); c->set_output(2, c->Scalar()); - const tensorflow::AttrValue* calibration_method_attr; - TF_RETURN_IF_ERROR( - c->GetAttr("calibration_method", &calibration_method_attr)); - int32_t num_bins = GetNumBins( - static_cast(calibration_method_attr->i())); - c->set_output(3, c->MakeShape({num_bins})); + const tensorflow::AttrValue* num_bins_attr; + TF_RETURN_IF_ERROR(c->GetAttr("num_bins", &num_bins_attr)); + c->set_output(3, c->MakeShape({num_bins_attr->i()})); return absl::OkStatus(); }); @@ -77,13 +72,12 @@ class CustomAggregatorOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("id", &id_)); int calibration_method_value; - int initial_num_bins; + int num_bins; float min_percentile; float max_percentile; OP_REQUIRES_OK(context, context->GetAttr("calibration_method", &calibration_method_value)); - OP_REQUIRES_OK(context, - context->GetAttr("initial_num_bins", &initial_num_bins)); + OP_REQUIRES_OK(context, context->GetAttr("num_bins", &num_bins)); OP_REQUIRES_OK(context, context->GetAttr("min_percentile", &min_percentile)); OP_REQUIRES_OK(context, @@ -98,8 +92,7 @@ class CustomAggregatorOp : public OpKernel { absl::AbortedError("The calibration method must be specified.")); calib_opts_.set_calibration_method(calibration_method); - calib_opts_.mutable_calibration_parameters()->set_initial_num_bins( - initial_num_bins); + calib_opts_.mutable_calibration_parameters()->set_num_bins(num_bins); calib_opts_.mutable_calibration_parameters()->set_min_percentile( min_percentile); calib_opts_.mutable_calibration_parameters()->set_max_percentile( @@ -123,7 +116,7 @@ class CustomAggregatorOp : public OpKernel { context->template eigen_device()) = input_flat.maximum(); // Calculate histogram statistics. 
- int32_t num_bins = GetNumBins(calib_opts_.calibration_method()); + const int32_t num_bins = calib_opts_.calibration_parameters().num_bins(); Tensor* histogram_output = nullptr; OP_REQUIRES_OK(context, context->allocate_output("histogram", {num_bins}, &histogram_output)); @@ -133,11 +126,6 @@ class CustomAggregatorOp : public OpKernel { CalculateHistogramStatistics(context, input_tensor, min_value, max_value, num_bins, histogram_output); } - - // By passing calib_opts_ and input_tensor to CalibratorSingleton, - // CalibrationStatisticsCollector can calculate statistics for calibration. - calibrator::CalibratorSingleton::Report(id_, *min_output, *max_output, - *histogram_output, calib_opts_); } private: diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py index 5940803f470117..78bd79e43faf1c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py @@ -17,9 +17,7 @@ import tensorflow # pylint: disable=unused-import from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as stablehlo_quant_config_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 as calib_stat_pb2 from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import custom_aggregator_op_wrapper -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import pywrap_calibration from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -39,7 +37,6 @@ def setUp(self): def testBypassAndMinMax(self): with self.session(): - pywrap_calibration.clear_calibrator() input_tensor = array_ops.constant( [1.0, 2.0, 3.0, 4.0, 5.0], dtypes.float32 ) @@ -55,18 +52,8 @@ def testBypassAndMinMax(self): self.assertEqual(aggregator_output.max, 5.0) self.assertEmpty(aggregator_output.histogram) - statistics: calib_stat_pb2.CalibrationStatistics = ( - pywrap_calibration.get_statistics_from_calibrator('1') - ) - - min_val = statistics.min_max_statistics.global_min - max_val = statistics.min_max_statistics.global_max - - self.assertAllEqual((min_val, max_val), (1.0, 5.0)) - def testTwoIdentities(self): with self.session(): - pywrap_calibration.clear_calibrator() input_tensor1 = array_ops.constant( [1.0, 2.0, 3.0, 4.0, 5.0], dtypes.float32 ) @@ -97,80 +84,8 @@ def testTwoIdentities(self): self.assertEqual(aggregator2_output.max, -1.0) self.assertEmpty(aggregator2_output.histogram) - statistics: calib_stat_pb2 = ( - pywrap_calibration.get_statistics_from_calibrator('2') - ) - min_val = statistics.min_max_statistics.global_min - max_val = statistics.min_max_statistics.global_max - self.assertAllEqual((min_val, max_val), (1.0, 5.0)) - statistics: calib_stat_pb2 = ( - pywrap_calibration.get_statistics_from_calibrator('3') - ) - min_val = statistics.min_max_statistics.global_min - max_val = statistics.min_max_statistics.global_max - self.assertAllEqual((min_val, max_val), (-5.0, -1.0)) - - def testClearData(self): - with self.session(): - pywrap_calibration.clear_calibrator() - input_tensor1 = array_ops.constant( - [1.0, 2.0, 3.0, 4.0, 5.0], dtypes.float32 - ) - aggregator1 = 
custom_aggregator_op_wrapper.custom_aggregator( - input_tensor1, - '4', - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX, - ) - aggregator1_output = self.evaluate(aggregator1) - self.assertAllEqual(aggregator1_output.output, [1.0, 2.0, 3.0, 4.0, 5.0]) - self.assertEqual(aggregator1_output.min, 1.0) - self.assertEqual(aggregator1_output.max, 5.0) - self.assertEmpty(aggregator1_output.histogram) - - input_tensor2 = array_ops.constant( - [-1.0, -2.0, -3.0, -4.0, -5.0], dtypes.float32 - ) - aggregator2 = custom_aggregator_op_wrapper.custom_aggregator( - input_tensor2, - '5', - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX, - ) - aggregator2_output = self.evaluate(aggregator2) - self.assertAllEqual( - aggregator2_output.output, [-1.0, -2.0, -3.0, -4.0, -5.0] - ) - self.assertEqual(aggregator2_output.min, -5.0) - self.assertEqual(aggregator2_output.max, -1.0) - self.assertEmpty(aggregator2_output.histogram) - - statistics: calib_stat_pb2 = ( - pywrap_calibration.get_statistics_from_calibrator('4') - ) - min_val = statistics.min_max_statistics.global_min - max_val = statistics.min_max_statistics.global_max - self.assertAllEqual((min_val, max_val), (1.0, 5.0)) - - statistics: calib_stat_pb2 = ( - pywrap_calibration.get_statistics_from_calibrator('5') - ) - min_val = statistics.min_max_statistics.global_min - max_val = statistics.min_max_statistics.global_max - self.assertAllEqual((min_val, max_val), (-5.0, -1.0)) - - pywrap_calibration.clear_data_from_calibrator('4') - with self.assertRaises(ValueError): - pywrap_calibration.get_statistics_from_calibrator('4') - - statistics: calib_stat_pb2 = ( - pywrap_calibration.get_statistics_from_calibrator('5') - ) - min_val = statistics.min_max_statistics.global_min - max_val = statistics.min_max_statistics.global_max - self.assertAllEqual((min_val, max_val), (-5.0, -1.0)) - def testBypassAndAverageMinMax(self): with self.session(): - pywrap_calibration.clear_calibrator() input_tensor1 = array_ops.constant( [-50.0, -25.0, 0.0, 25.0, 50.0], dtypes.float32 ) @@ -204,19 +119,8 @@ def testBypassAndAverageMinMax(self): self.assertEqual(aggregator2_output.max, 100.0) self.assertEmpty(aggregator2_output.histogram) - statistics: calib_stat_pb2 = ( - pywrap_calibration.get_statistics_from_calibrator('6') - ) - - min_sum = statistics.average_min_max_statistics.min_sum - max_sum = statistics.average_min_max_statistics.max_sum - num_samples = statistics.average_min_max_statistics.num_samples - - self.assertAllEqual((min_sum, max_sum, num_samples), (-150.0, 150.0, 2)) - def testHistogramCalibration(self): with self.session(): - pywrap_calibration.clear_calibrator() input_tensor = array_ops.constant( [1.0, 1.0, 3.0, 4.0, 6.0], dtypes.float32 ) @@ -225,7 +129,7 @@ def testHistogramCalibration(self): input_tensor, id='7', calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, - initial_num_bins=256, + num_bins=512, ) aggregator_output = self.evaluate(aggregator) self.assertAllEqual(aggregator_output.output, [1.0, 1.0, 3.0, 4.0, 6.0]) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.cc deleted file mode 100644 index 8f7c4e30457a2e..00000000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <optional>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "pybind11/pybind11.h"  // from @pybind11
-#include "pybind11_protobuf/native_proto_caster.h"  // from @pybind11_protobuf
-#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h"
-#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h"
-
-namespace py = ::pybind11;
-
-namespace {
-
-using ::tensorflow::calibrator::CalibrationStatistics;
-using ::tensorflow::calibrator::CalibratorSingleton;
-
-// Retrieves collected statistics of a `CustomAggregator` node from the
-// singleton. `id` is the identifier of the `CustomAggregator`.
-CalibrationStatistics GetStatisticsFromCalibrator(const absl::string_view id) {
-  std::optional<CalibrationStatistics> statistics =
-      CalibratorSingleton::GetStatistics(id);
-
-  if (!statistics.has_value()) {
-    throw py::value_error(absl::StrFormat(
-        "Calibrated data does not exist. Cannot find statistics."
-        "value for id: '%s'",
-        id));
-  }
-
-  return *statistics;
-}
-
-}  // namespace
-
-PYBIND11_MODULE(pywrap_calibration, m) {
-  // Allows type casting protobuf objects.
-  pybind11_protobuf::ImportNativeProtoCasters();
-
-  m.doc() = "Defines functions for interacting with CalibratorSingleton.";
-
-  m.def(
-      // If the function signature changes, likely its corresponding .pyi type
-      // hinting should also change.
-      // LINT.IfChange
-      "clear_calibrator",
-      []() -> void
-      // LINT.ThenChange(pywrap_calibration.pyi:clear_calibrator)
-      { CalibratorSingleton::ClearCollectedInformation(); },
-      R"pbdoc(
-      Clears the collected metrics from the calibrator.
-    )pbdoc");
-  m.def(
-      // If the function signature changes, likely its corresponding .pyi type
-      // hinting should also change.
-      // LINT.IfChange
-      "clear_data_from_calibrator",
-      [](const absl::string_view id) -> void
-      // LINT.ThenChange(pywrap_calibration.pyi:clear_data_from_calibrator)
-      { CalibratorSingleton::ClearData(id); },
-      R"pbdoc(
-      Clears the collected data of the given id from calibrator.
-    )pbdoc",
-      py::arg("id"));
-  m.def(
-      // If the function signature changes, likely its corresponding .pyi type
-      // hinting should also change.
-      // LINT.IfChange
-      "get_statistics_from_calibrator",
-      [](const absl::string_view id) -> CalibrationStatistics {
-        // LINT.ThenChange(pywrap_calibration.pyi:get_statistics_from_calibrator)
-        return GetStatisticsFromCalibrator(id);
-      },
-      R"pbdoc(
-      Returns the proto CalibrationStatistics given id from calibrator.
-    )pbdoc",
-      py::arg("id"));
-}
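With pywrap_calibration deleted, calibration results no longer round-trip through CalibratorSingleton: the CustomAggregator kernel now emits min, max, and an optional fixed-width histogram as ordinary op outputs, which the Python tests above read straight off the evaluated tensors. A minimal sketch of that output contract, assuming a non-empty float input; the helper name Aggregate is hypothetical and only approximates what CustomAggregatorOp and CalculateHistogramStatistics compute:

  // Sketch: min/max plus a fixed-width histogram, mirroring the op's four
  // outputs (output, min, max, histogram). Not the kernel's actual code.
  #include <algorithm>
  #include <cstdint>
  #include <vector>

  struct AggregatorOutput {
    float min;
    float max;
    std::vector<int64_t> histogram;  // Empty unless a histogram method is set.
  };

  AggregatorOutput Aggregate(const std::vector<float>& values,
                             int32_t num_bins) {
    const auto [min_it, max_it] =
        std::minmax_element(values.begin(), values.end());
    AggregatorOutput out{*min_it, *max_it, {}};
    if (num_bins > 0) {
      out.histogram.assign(num_bins, 0);
      const float width = (out.max - out.min) / num_bins;
      for (const float v : values) {
        const int32_t bin =
            width > 0 ? static_cast<int32_t>((v - out.min) / width) : 0;
        ++out.histogram[std::min(bin, num_bins - 1)];  // Clamp into last bin.
      }
    }
    return out;
  }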
- )pbdoc", - py::arg("id")); -} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.cc index 565adebfe52300..64695d6719885d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.cc @@ -71,7 +71,7 @@ LogicalResult FoldOperation(OpBuilder& builder, Operation* op, bool IsOperationFoldable(Operation* op) { if (isa(op)) return true; - if (!op->getDialect()->getNamespace().equals("tf") || !TF::CanBeFolded(op)) { + if (op->getDialect()->getNamespace() != "tf" || !TF::CanBeFolded(op)) { return false; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc index c2445456339fb9..60d2c07bdab8ea 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc @@ -93,8 +93,7 @@ SmallVector GetEntryFunctionInputs(func::FuncOp func_op) { func_op->getAttrOfType("tf.entry_function"); SmallVector inputs; - entry_function_attr.get("inputs") - .dyn_cast_or_null() + mlir::dyn_cast_or_null(entry_function_attr.get("inputs")) .strref() .split(inputs, /*Separator=*/","); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args_test.cc index 238b1bb8ef8955..d77859a67c9dca 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args_test.cc @@ -94,9 +94,11 @@ TEST_F(ConvertAssetArgsTest, ConvertsSingleAssetArg) { EXPECT_THAT(arg_attrs.get("tf_saved_model.bound_input"), IsNull()); const ArrayRef index_path_attrs = - arg_attrs.get("tf_saved_model.index_path").cast().getValue(); + mlir::cast(arg_attrs.get("tf_saved_model.index_path")) + .getValue(); EXPECT_THAT(index_path_attrs, SizeIs(1)); - StringAttr index_path = index_path_attrs[0].dyn_cast_or_null(); + StringAttr index_path = + mlir::dyn_cast_or_null(index_path_attrs[0]); EXPECT_THAT(index_path, NotNull()); EXPECT_THAT(index_path, Eq("arg_0:0")); } @@ -122,9 +124,11 @@ TEST_F(ConvertAssetArgsTest, NonBoundedArgsNotModified) { EXPECT_THAT(arg_attrs.get("tf_saved_model.bound_input"), IsNull()); const ArrayRef index_path_attrs = - arg_attrs.get("tf_saved_model.index_path").cast().getValue(); + mlir::cast(arg_attrs.get("tf_saved_model.index_path")) + .getValue(); EXPECT_THAT(index_path_attrs, SizeIs(1)); - StringAttr index_path = index_path_attrs[0].dyn_cast_or_null(); + StringAttr index_path = + mlir::dyn_cast_or_null(index_path_attrs[0]); EXPECT_THAT(index_path, NotNull()); EXPECT_THAT(index_path, Eq("arg_0:0")); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc index 7be369e7947ced..8ba632b66ae0f3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc @@ -65,9 +65,9 @@ bool QuantizationUnitLoc::classof(Attribute attr) { if (!llvm::isa(attr)) return false; auto callsite_loc = llvm::dyn_cast(attr); - if (!callsite_loc.getCaller().isa()) return false; + if (!mlir::isa(callsite_loc.getCaller())) return false; StringRef 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc
index 7be369e7947ced..8ba632b66ae0f3 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc
@@ -65,9 +65,9 @@ bool QuantizationUnitLoc::classof(Attribute attr) {
   if (!llvm::isa<CallSiteLoc>(attr)) return false;
   auto callsite_loc = llvm::dyn_cast<CallSiteLoc>(attr);
-  if (!callsite_loc.getCaller().isa<NameLoc>()) return false;
+  if (!mlir::isa<NameLoc>(callsite_loc.getCaller())) return false;
   StringRef caller_name =
-      callsite_loc.getCaller().cast<NameLoc>().getName().strref();
+      mlir::cast<NameLoc>(callsite_loc.getCaller()).getName().strref();
   return caller_name.starts_with(kQuantizationUnitPrefix) &&
          caller_name.ends_with(kQuantizationUnitSuffix);
 }
@@ -75,8 +75,8 @@ std::optional<QuantizationUnitLoc::QuantizationUnit>
 FindQuantizationUnitFromLoc(Location loc) {
   if (isa<QuantizationUnitLoc>(loc)) {
-    Location caller = loc.cast<CallSiteLoc>().getCaller();
-    StringRef caller_name = caller.cast<NameLoc>().getName().strref();
+    Location caller = mlir::cast<CallSiteLoc>(loc).getCaller();
+    StringRef caller_name = mlir::cast<NameLoc>(caller).getName().strref();
     const size_t start_index = kQuantizationUnitPrefix.size();
     const size_t end_index = caller_name.rfind(kQuantizationUnitSuffix);
     std::string serialized_proto =
@@ -87,14 +87,15 @@ FindQuantizationUnitFromLoc(Location loc) {
     }
   } else if (isa<FusedLoc>(loc)) {
     // If the op is rewritten, FusedLoc can be created.
-    for (Location child_loc : loc.cast<FusedLoc>().getLocations()) {
+    for (Location child_loc : mlir::cast<FusedLoc>(loc).getLocations()) {
       std::optional<QuantizationUnitLoc::QuantizationUnit> found_unit =
           FindQuantizationUnitFromLoc(child_loc);
       if (found_unit.has_value()) return found_unit;
     }
   } else if (isa<CallSiteLoc>(loc)) {
     // If the graph is inlined, CallSiteLoc can be created.
-    return FindQuantizationUnitFromLoc(loc.cast<CallSiteLoc>().getCallee());
+    return FindQuantizationUnitFromLoc(
+        mlir::cast<CallSiteLoc>(loc).getCallee());
   }
   return std::nullopt;
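For orientation: a QuantizationUnitLoc carries a serialized QuantizationUnit proto inside a NameLoc name, bracketed by kQuantizationUnitPrefix and kQuantizationUnitSuffix, which is exactly what the substr/rfind arithmetic above recovers. A standalone sketch of that parse; the prefix and suffix literals here are assumptions for illustration, not the real constants:

  #include <optional>
  #include <string>

  constexpr char kPrefix[] = "quant_unit[";  // Assumed shape, not the TF value.
  constexpr char kSuffix[] = "]";

  std::optional<std::string> ExtractSerializedUnit(const std::string& name) {
    constexpr size_t kPrefixLen = sizeof(kPrefix) - 1;
    if (name.rfind(kPrefix, 0) != 0) return std::nullopt;  // starts_with check
    const size_t end = name.rfind(kSuffix);
    if (end == std::string::npos || end < kPrefixLen) return std::nullopt;
    return name.substr(kPrefixLen, end - kPrefixLen);  // serialized proto bytes
  }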
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc
index 52ca3722a12bd5..9630b20b32d571 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc
@@ -56,7 +56,7 @@ bool IsOpWithInt8TypeOperand(Operation* op) {
 }
 bool IsValueWithQuantizablePrecision(Value val) {
-  auto type = val.getType().dyn_cast<ShapedType>();
+  auto type = mlir::dyn_cast<ShapedType>(val.getType());
   if (!type) return false;
   // Supported original tensor data types.
   if (type.getElementType().isF32() || type.getElementType().isBF16())
@@ -82,7 +82,7 @@ std::unique_ptr<OpQuantSpec> GetTFOpQuantSpec(Operation* op) {
   auto spec = std::make_unique<OpQuantSpec>();
   if (auto call_op = dyn_cast<TF::PartitionedCallOp>(op)) {
     StringRef function_name =
-        call_op.getFAttr().cast<FlatSymbolRefAttr>().getValue();
+        mlir::cast<FlatSymbolRefAttr>(call_op.getFAttr()).getValue();
     if (!function_name.starts_with("composite_")) {
       return spec;
     }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc
index 723adde447e546..47beb9e0c2636f 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc
@@ -153,11 +153,10 @@ QuantizedType CalculateUniformQuantParams(
   DenseFPElementsAttr attr;
   if (!matchPattern(op->getResult(0), m_Constant(&attr))) return nullptr;
-  QuantizedType quant_type =
+  QuantizedType quant_type = mlir::dyn_cast<QuantizedType>(
       quant::GetUniformQuantizedTypeForWeight(
           attr, /*symmetric=*/kIsNarrowRange && kIsSigned, kBitWidth, kIsSigned,
-          kIsNarrowRange, /*is_legacy_float*/ false)
-          .template dyn_cast<QuantizedType>();
+          kIsNarrowRange, /*is_legacy_float*/ false));
   return quant_type;
 }
@@ -172,16 +171,16 @@ std::optional<Value> AddUniformQuantizeOps(PatternRewriter& rewriter,
   }
   Type expressed_type = op.getResult().getType();
   Type quantized_type = quant_type.castFromExpressedType(expressed_type);
-  ShapedType shaped_quantized_type = quantized_type.cast<ShapedType>();
+  ShapedType shaped_quantized_type = mlir::cast<ShapedType>(quantized_type);
   DenseElementsAttr tensor_proto_attr =
-      Quantize(attr, shaped_quantized_type).dyn_cast<DenseElementsAttr>();
+      mlir::dyn_cast<DenseElementsAttr>(Quantize(attr, shaped_quantized_type));
   if (!tensor_proto_attr) {
     return nullptr;
   }
-  Type storage_type = shaped_quantized_type.getElementType()
-                          .cast<QuantizedType>()
-                          .getStorageType();
+  Type storage_type =
+      mlir::cast<QuantizedType>(shaped_quantized_type.getElementType())
+          .getStorageType();
   ShapedType new_type = shaped_quantized_type.clone(storage_type);
   rewriter.setInsertionPointAfter(op);
@@ -205,7 +204,7 @@ Operation* LogicsForUniformDequanization(PatternRewriter& rewriter,
   auto new_cast_op =
       rewriter.create<TF::CastOp>(loc, create_unknown_input_shape, input_val);
   // TODO - b/278949920: Enable Per-Channel Quantization for XLA Opset
-  auto qtype = quant_type.dyn_cast<UniformQuantizedType>();
+  auto qtype = mlir::dyn_cast<UniformQuantizedType>(quant_type);
   TensorType scale_type = RankedTensorType::get({}, rewriter.getF32Type());
   Value scale_op = rewriter.create<TF::ConstOp>(
       loc, scale_type,
@@ -253,7 +252,7 @@ std::optional<Value> ApplyUniformQuantization(
   std::optional<Value> dequantized_val =
       AddUniformDequantizeOps(rewriter, quant_type, quantized_val.value(),
-                              op.getType().cast<ShapedType>());
+                              mlir::cast<ShapedType>(op.getType()));
   return dequantized_val;
 }
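CalculateUniformQuantParams above requests a symmetric, signed, narrow-range 8-bit type from quant::GetUniformQuantizedTypeForWeight (kBitWidth, kIsSigned, and kIsNarrowRange are constants defined earlier in that file). The scale/zero-point arithmetic this reduces to is the standard formulation; a sketch for orientation only, not the library's implementation:

  #include <algorithm>
  #include <cmath>
  #include <cstdint>

  struct QuantParams {
    double scale;
    int64_t zero_point;  // Always 0 for symmetric quantization.
  };

  // Narrow range maps weights onto [-127, 127] so that 0.0f hits the zero
  // point exactly, which keeps symmetric int8 kernels bias-free.
  QuantParams SymmetricInt8Params(float min_value, float max_value) {
    const double max_abs =
        std::max(std::fabs(min_value), std::fabs(max_value));
    return {max_abs / 127.0, 0};
  }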
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc
index d390ac6d548e78..109fa943f9334b 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc
@@ -69,17 +69,17 @@ class AddQuantizationUnitLocPass
 // tensorflow/compiler/mlir/tensorflow/translate/import_model.cc for more
 // details.
 bool IsImportLocPattern(FusedLoc loc) {
-  ArrayRef<Location> locations = loc.cast<FusedLoc>().getLocations();
+  ArrayRef<Location> locations = mlir::cast<FusedLoc>(loc).getLocations();
   if (locations.size() < 2 || !isa<NameLoc>(locations.front())) return false;
   StringRef op_type_with_suffix =
-      locations.front().cast<NameLoc>().getName().strref();
+      mlir::cast<NameLoc>(locations.front()).getName().strref();
   if (!op_type_with_suffix.ends_with(":")) return false;
   return absl::c_all_of(locations, [](Location loc) {
     return isa<NameLoc>(loc) ||
            (isa<CallSiteLoc>(loc) &&
-            isa<NameLoc>(loc.cast<CallSiteLoc>().getCallee()));
+            isa<NameLoc>(mlir::cast<CallSiteLoc>(loc).getCallee()));
   });
 }
@@ -99,23 +99,23 @@ void FindQuantizationUnitsRecursively(Location loc,
     }
   };
-  ArrayRef<Location> locations = loc.cast<FusedLoc>().getLocations();
-  if (IsImportLocPattern(loc.cast<FusedLoc>())) {
+  ArrayRef<Location> locations = mlir::cast<FusedLoc>(loc).getLocations();
+  if (IsImportLocPattern(mlir::cast<FusedLoc>(loc))) {
     QuantizationUnit new_unit;
     // Op type is a NameLoc with the ":" suffix.
     StringRef op_type_with_suffix =
-        locations.front().cast<NameLoc>().getName().strref();
+        mlir::cast<NameLoc>(locations.front()).getName().strref();
     StringRef op_type =
         op_type_with_suffix.substr(0, op_type_with_suffix.size() - 1);
     new_unit.set_op_type(op_type.str());
     if (isa<NameLoc>(locations.back())) {
       StringRef name_loc_id =
-          locations.back().cast<NameLoc>().getName().strref();
+          mlir::cast<NameLoc>(locations.back()).getName().strref();
       set_node_and_func_name(new_unit, name_loc_id);
     } else {
-      Location callee = locations.back().cast<CallSiteLoc>().getCallee();
-      StringRef name_loc_id = callee.cast<NameLoc>().getName().strref();
+      Location callee = mlir::cast<CallSiteLoc>(locations.back()).getCallee();
+      StringRef name_loc_id = mlir::cast<NameLoc>(callee).getName().strref();
       set_node_and_func_name(new_unit, name_loc_id);
     }
     units.push_back(new_unit);
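IsImportLocPattern and FindQuantizationUnitsRecursively match the location shape the GraphDef importer produces: a FusedLoc whose first element is a NameLoc holding the op type plus a trailing ':', followed by the node name either directly or behind a CallSiteLoc once the graph has been inlined. Roughly, in textual form (approximate and printer-dependent), together with the trailing-colon strip:

  // fused["Conv2D:", "model/conv2d/Conv2D"]
  // fused["Conv2D:", callsite("model/conv2d/Conv2D" at ...)]
  #include "llvm/ADT/StringRef.h"

  // Mirrors the substr(0, size() - 1) above: "Conv2D:" -> "Conv2D".
  llvm::StringRef StripOpTypeSuffix(llvm::StringRef op_type_with_suffix) {
    return op_type_with_suffix.drop_back(1);
  }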
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc
index e4229cb97bf45a..8c02ace87d8001 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
@@ -75,8 +76,8 @@ class ConvertCustomAggregationOpToQuantStats
   LogicalResult matchAndRewrite(TF::CustomAggregatorOp op,
                                 PatternRewriter &rewriter) const override {
-    FloatAttr min = op->getAttr("min").dyn_cast_or_null<FloatAttr>();
-    FloatAttr max = op->getAttr("max").dyn_cast_or_null<FloatAttr>();
+    FloatAttr min = mlir::dyn_cast_or_null<FloatAttr>(op->getAttr("min"));
+    FloatAttr max = mlir::dyn_cast_or_null<FloatAttr>(op->getAttr("max"));
     // When there are no min and max attributes, remove op.
     if (min == nullptr || max == nullptr) {
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc
index d23a0f8d3a7af2..c39492f0efe709 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc
@@ -158,10 +158,8 @@ Value CreateEinsumOpFromXlaDotV2Op(OpBuilder& builder, const Location loc,
   xla::DotDimensionNumbers dot_dimension_numbers;
   dot_dimension_numbers.ParseFromString(dot_dimension_numbers_str.str());
   SmallVector<Value> input_arguments = {lhs, rhs};
-  const int lhs_rank =
-      lhs.getType().template cast<ShapedType>().getShape().size();
-  const int rhs_rank =
-      rhs.getType().template cast<ShapedType>().getShape().size();
+  const int lhs_rank = mlir::cast<ShapedType>(lhs.getType()).getShape().size();
+  const int rhs_rank = mlir::cast<ShapedType>(rhs.getType()).getShape().size();
   const std::string einsum_equation =
       CreateEinsumEquation(dot_dimension_numbers, lhs_rank, rhs_rank);
@@ -218,7 +216,7 @@ RankedTensorType RestoreCollapsedDimensions(
 Type GetSliceOpOutputType(Type xla_gather_op_output_type,
                           const absl::flat_hash_set<int64_t>& collapsed_dims) {
   if (auto ranked_output_type =
-          xla_gather_op_output_type.dyn_cast<RankedTensorType>();
+          mlir::dyn_cast<RankedTensorType>(xla_gather_op_output_type);
       ranked_output_type) {
     return RestoreCollapsedDimensions(ranked_output_type, collapsed_dims);
   }
@@ -228,9 +226,9 @@ Type GetSliceOpOutputType(Type xla_gather_op_output_type,
 // TODO (b/275225582): Supports Xla Gather op in general case.
 bool IsXlaGatherWithoutBatch(Value operand, Value start_indices) {
-  auto operand_type = operand.getType().dyn_cast_or_null<ShapedType>();
+  auto operand_type = mlir::dyn_cast_or_null<ShapedType>(operand.getType());
   auto start_indices_type =
-      start_indices.getType().dyn_cast_or_null<ShapedType>();
+      mlir::dyn_cast_or_null<ShapedType>(start_indices.getType());
   if (start_indices_type == nullptr || operand_type == nullptr) return false;
   return start_indices_type.getShape().size() == 1;
 }
@@ -245,7 +243,7 @@ Value CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch(
   // Construct full start_indices with given start_indices and
   // start_index_map.
   const ArrayRef<int64_t> operand_shape =
-      operand.getType().cast<ShapedType>().getShape();
+      mlir::cast<ShapedType>(operand.getType()).getShape();
   const int64_t operand_rank = operand_shape.size();
   // Fills zeros if start_index is not given in start_indices.
@@ -273,7 +271,7 @@ Value CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch(
       builder.create<TF::CastOp>(
           loc,
           RankedTensorType::get(
-              start_indices.getType().template cast<ShapedType>().getShape(),
+              mlir::cast<ShapedType>(start_indices.getType()).getShape(),
              builder.getI64Type()),
          start_indices));
@@ -289,7 +287,7 @@ Value CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch(
       builder.create<TF::CastOp>(
           loc,
           RankedTensorType::get(
-              slice_sizes.getType().template cast<ShapedType>().getShape(),
+              mlir::cast<ShapedType>(slice_sizes.getType()).getShape(),
              builder.getI64Type()),
          slice_sizes));
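CreateEinsumOpFromXlaDotV2Op rewrites XlaDotV2 as a TF Einsum whose equation is derived from the operand ranks and the DotDimensionNumbers proto. A simplified builder for the no-batch, single-contracting-dimension case; the name and signature are hypothetical, and the real CreateEinsumEquation also threads batch dimensions through:

  #include <string>

  // Label every dimension, reuse one label across the contracting pair, then
  // emit "lhs,rhs->out". E.g. SimpleDotEquation(2, 2, 1, 0) == "ab,bc->ac".
  std::string SimpleDotEquation(int lhs_rank, int rhs_rank,
                                int lhs_contracting, int rhs_contracting) {
    std::string lhs, rhs, out;
    char next = 'a';
    for (int i = 0; i < lhs_rank; ++i) lhs += next++;
    for (int i = 0; i < rhs_rank; ++i)
      rhs += (i == rhs_contracting) ? lhs[lhs_contracting] : next++;
    for (int i = 0; i < lhs_rank; ++i)
      if (i != lhs_contracting) out += lhs[i];
    for (int i = 0; i < rhs_rank; ++i)
      if (i != rhs_contracting) out += rhs[i];
    return lhs + "," + rhs + "->" + out;
  }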
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -92,7 +93,7 @@ class ReplaceTpuPartitionedCallOpWithPartitionedCallOp private: LogicalResult matchAndRewrite(TF::TPUPartitionedCallOp call_op, PatternRewriter& rewriter) const override { - auto f_attr = call_op.getFAttr().dyn_cast(); + auto f_attr = mlir::dyn_cast(call_op.getFAttr()); auto module_op = call_op->getParentOfType(); SymbolTable symbol_table(module_op); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index 5ed89d89339571..c5d7ca8e47f6f9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -14,10 +14,14 @@ limitations under the License. ==============================================================================*/ #include #include +#include +#include #include #include "absl/status/statusor.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -28,6 +32,7 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -53,6 +58,40 @@ using ::stablehlo::quantization::Method; constexpr StringRef kQuantTraitAttrName = "_tfl_quant_trait"; +// Whether the op is a call op to lifted composite function. +bool IsCallToQuantizableLiftedFunction(Operation *op) { + if (!op) return false; + if (auto xla_call_module_op = dyn_cast_or_null(op); + xla_call_module_op != nullptr) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.ok() && method->has_static_range_ptq()) return true; + } + + TF::PartitionedCallOp call_op = dyn_cast_or_null(op); + return call_op && call_op->hasAttrOfType(kQuantTraitAttrName) && + call_op->getAttrOfType(kQuantTraitAttrName).getValue() == + llvm::StringRef( + QuantTraitValues[QuantizationTrait::FullyQuantizable]); +} + +// Returns the composite function name. 
+// Returns the composite function name.
+std::optional<StringRef> GetCompositeFunctionName(Operation *op) {
+  if (!IsCallToQuantizableLiftedFunction(op)) return std::nullopt;
+
+  if (auto xla_call_module_op = dyn_cast_or_null<TF::XlaCallModuleOp>(op);
+      xla_call_module_op != nullptr) {
+    auto entry_function_attr = xla_call_module_op->getAttrOfType<StringAttr>(
+        kOriginalStablehloEntryFunctionAttrName);
+    if (!entry_function_attr) return std::nullopt;
+    return entry_function_attr.getValue();
+  } else {
+    TF::PartitionedCallOp call_op = dyn_cast_or_null<TF::PartitionedCallOp>(op);
+    const auto f_attr = call_op.getFAttr().dyn_cast<FlatSymbolRefAttr>();
+    if (!f_attr) return std::nullopt;
+    return f_attr.getValue();
+  }
+}
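// Editorial note (not part of the patch): the name returned above seeds the
// deterministic aggregator ids minted in AddCustomAggregationOp below, which
// is what lets this change delete the counter-based
// IssueIDsOfCustomAggregationOpsPass further down. The id format, assembled
// with llvm::Twine, is
//
//   <composite function name>_arg_<operand index>_calibration_method_<enum>
//
// (the "_arg_<n>" part is dropped for the single-result output case), e.g. a
// hypothetical "composite_conv2d_fn_1_arg_0_calibration_method_2". An
// equivalent helper, for illustration only:
//
//   std::string MakeAggregatorId(llvm::StringRef func_name, int arg_idx,
//                                int method) {
//     return (llvm::Twine(func_name) + "_arg_" + llvm::Twine(arg_idx) +
//             "_calibration_method_" + llvm::Twine(method))
//         .str();
//   }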
+
 class InsertCustomAggregationOpsPass
     : public PassWrapper<InsertCustomAggregationOpsPass,
                          OperationPass<func::FuncOp>> {
@@ -145,7 +184,7 @@ class InsertCustomAggregationOpsPass
             CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE);
         auto calibration_parameters =
             CalibrationOptions::CalibrationParameters();
-        calibration_parameters.set_initial_num_bins(256);
+        calibration_parameters.set_num_bins(512);
         calibration_parameters.set_min_percentile(0.001);
         calibration_parameters.set_max_percentile(99.999);
         calib_opts_.mutable_calibration_parameters()->CopyFrom(
             calibration_parameters);
         break;
@@ -157,7 +196,7 @@ class InsertCustomAggregationOpsPass
             CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE);
         auto calibration_parameters =
             CalibrationOptions::CalibrationParameters();
-        calibration_parameters.set_initial_num_bins(256);
+        calibration_parameters.set_num_bins(512);
         calib_opts_.mutable_calibration_parameters()->CopyFrom(
             calibration_parameters);
         break;
@@ -167,7 +206,7 @@ class InsertCustomAggregationOpsPass
             CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY);
         auto calibration_parameters =
             CalibrationOptions::CalibrationParameters();
-        calibration_parameters.set_initial_num_bins(256);
+        calibration_parameters.set_num_bins(512);
         calib_opts_.mutable_calibration_parameters()->CopyFrom(
             calibration_parameters);
         break;
@@ -177,7 +216,7 @@ class InsertCustomAggregationOpsPass
             CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC);
         auto calibration_parameters =
             CalibrationOptions::CalibrationParameters();
-        calibration_parameters.set_initial_num_bins(256);
+        calibration_parameters.set_num_bins(512);
         calib_opts_.mutable_calibration_parameters()->CopyFrom(
             calibration_parameters);
         break;
@@ -204,17 +243,22 @@ class AddCustomAggregationOp : public RewritePattern {
     // The CustomAggregatorOp is only added after quantizable values.
     SmallVector<Value> quantizable_values;
+    SmallVector<std::string> aggregator_ids;
     if (IsCallToQuantizableLiftedFunction(op)) {
+      std::optional<StringRef> composite_function_name =
+          GetCompositeFunctionName(op);
+      if (!composite_function_name.has_value()) return failure();
+
       // Quantize inputs of quantizable composite functions.
-      for (Value input : op->getOperands()) {
-        Type element_type = getElementTypeOrSelf(input.getType());
+      for (OpOperand &input : op->getOpOperands()) {
+        Type element_type = getElementTypeOrSelf(input.get().getType());
         // Non-float cases won't be calibrated.
         if (!element_type.isF32()) {
           continue;
         }
         // Skip when there is any already existing CustomAggregatorOp found.
-        Operation *defining_op = input.getDefiningOp();
+        Operation *defining_op = input.get().getDefiningOp();
         if (dyn_cast_or_null<TF::CustomAggregatorOp>(defining_op)) {
           continue;
         }
@@ -225,41 +269,51 @@ class AddCustomAggregationOp : public RewritePattern {
           continue;
         }
-        quantizable_values.push_back(input);
+        quantizable_values.push_back(input.get());
+        aggregator_ids.push_back(
+            (llvm::Twine(composite_function_name.value()) + "_arg_" +
+             llvm::Twine(input.getOperandNumber()) + "_calibration_method_" +
+             llvm::Twine(calib_opts_.calibration_method()))
+                .str());
       }
     } else {
       // Quantize output of fully quantizable composite functions.
       for (Value input : op->getOperands()) {
         auto defining_op = input.getDefiningOp();
-        if (!IsCallToQuantizableLiftedFunction(defining_op)) {
-          continue;
-        }
+        std::optional<StringRef> composite_function_name =
+            GetCompositeFunctionName(defining_op);
+        if (!composite_function_name.has_value()) continue;
         // Do not add CustomAggregatorOp after Gather since it is a weight-only
         // quantizable op.
         if (auto call_op =
                 dyn_cast_or_null<TF::PartitionedCallOp>(defining_op)) {
           StringRef function_name =
-              call_op.getFAttr().cast<FlatSymbolRefAttr>().getValue();
+              mlir::cast<FlatSymbolRefAttr>(call_op.getFAttr()).getValue();
           if (function_name.contains("gather")) continue;
         }
         quantizable_values.push_back(input);
+        // All composite functions have a single result at the moment.
+        aggregator_ids.push_back((llvm::Twine(composite_function_name.value()) +
+                                  "_calibration_method_" +
+                                  llvm::Twine(calib_opts_.calibration_method()))
+                                     .str());
      }
    }
    if (quantizable_values.empty()) return failure();
-    for (Value value : quantizable_values) {
+    int32_t effective_num_bins = GetNumBins(calib_opts_);
+    for (auto [value, aggregator_id] :
+         llvm::zip_equal(quantizable_values, aggregator_ids)) {
       // ID attribute will have empty value for now.
       SmallVector<NamedAttribute> attributes{
-          rewriter.getNamedAttr("id", rewriter.getStringAttr("")),
+          rewriter.getNamedAttr("id", rewriter.getStringAttr(aggregator_id)),
          rewriter.getNamedAttr(
              "calibration_method",
              rewriter.getI32IntegerAttr(calib_opts_.calibration_method())),
-          rewriter.getNamedAttr(
-              "initial_num_bins",
-              rewriter.getI32IntegerAttr(
-                  calib_opts_.calibration_parameters().initial_num_bins())),
+          rewriter.getNamedAttr("num_bins",
+                                rewriter.getI32IntegerAttr(effective_num_bins)),
          rewriter.getNamedAttr(
              "min_percentile",
              rewriter.getF32FloatAttr(
                  calib_opts_.calibration_parameters().min_percentile())),
          rewriter.getNamedAttr(
              "max_percentile",
              rewriter.getF32FloatAttr(
                  calib_opts_.calibration_parameters().max_percentile())),
      };
-      int32_t num_bins = GetNumBins(calib_opts_.calibration_method());
       SmallVector<Type, 4> output_types{
           value.getType(),
           RankedTensorType::get({}, rewriter.getF32Type()),
           RankedTensorType::get({}, rewriter.getF32Type()),
-          RankedTensorType::get({num_bins}, rewriter.getI64Type()),
+          RankedTensorType::get({effective_num_bins}, rewriter.getI64Type()),
       };
       // Insert custom aggregation op between operand and operator.
@@ -292,22 +345,6 @@ class AddCustomAggregationOp : public RewritePattern {
  private:
   CalibrationOptions calib_opts_;
-
-  // Whether the op is a call op to lifted composite function.
- bool IsCallToQuantizableLiftedFunction(Operation *op) const { - if (!op) return false; - if (auto xla_call_module_op = dyn_cast_or_null(op); - xla_call_module_op != nullptr) { - absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); - if (method.ok() && method->has_static_range_ptq()) return true; - } - - TF::PartitionedCallOp call_op = dyn_cast_or_null(op); - return call_op && call_op->hasAttrOfType(kQuantTraitAttrName) && - call_op->getAttrOfType(kQuantTraitAttrName) - .getValue() - .equals(QuantTraitValues[QuantizationTrait::FullyQuantizable]); - } }; void InsertCustomAggregationOpsPass::runOnOperation() { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc index 682889917c112e..0f855088d17943 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc @@ -154,7 +154,7 @@ void GetUniqueInputOutputNodeNames(ModuleOp module_op, if (auto inputs_attr = tf_attrs.get("inputs")) { const std::string inputs_attr_str = - inputs_attr.cast().getValue().str(); + mlir::cast(inputs_attr).getValue().str(); std::vector fn_input_names = absl::StrSplit(inputs_attr_str, ',', absl::SkipEmpty()); @@ -174,7 +174,7 @@ void GetUniqueInputOutputNodeNames(ModuleOp module_op, if (auto outputs_attr = tf_attrs.get("outputs")) { const std::string outputs_attr_str = - outputs_attr.cast().getValue().str(); + mlir::cast(outputs_attr).getValue().str(); std::vector fn_output_names = absl::StrSplit(outputs_attr_str, ',', absl::SkipEmpty()); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/issue_ids_of_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/issue_ids_of_custom_aggregation_ops.cc deleted file mode 100644 index 1100a903f6e226..00000000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/issue_ids_of_custom_aggregation_ops.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include -#include -#include -#include -#include - -#include "llvm/ADT/StringRef.h" -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" - -namespace mlir { -namespace quant { -namespace { - -class IssueIDsOfCustomAggregationOpsPass - : public PassWrapper> { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - IssueIDsOfCustomAggregationOpsPass) - - StringRef getArgument() const final { - // This is the argument used to refer to the pass in the textual format (on - // the commandline for example). - return "quant-issues-ids-of-custom-aggregation-ops"; - } - - StringRef getDescription() const final { - // This is a brief description of the pass. - return "Issue IDs of custom aggregation ops for the calibration procedure"; - } - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } - - private: - void runOnOperation() override; - - void issueIdToCustomAggregator(Operation* op); - - // Count of aggregator ops encountered; - int aggregator_count_; -}; - -static PassRegistration pass; - -void IssueIDsOfCustomAggregationOpsPass::issueIdToCustomAggregator( - Operation* op) { - // Return early when only aggregator operators are given. - if (!dyn_cast_or_null(op)) return; - - // Issue id based on the number of aggregators found. - OpBuilder builder(op); - op->setAttr("id", builder.getStringAttr(std::to_string(aggregator_count_))); - ++aggregator_count_; -} - -void IssueIDsOfCustomAggregationOpsPass::runOnOperation() { - ModuleOp module = getOperation(); - module.walk([&](Operation* op) { issueIdToCustomAggregator(op); }); -} - -} // namespace - -std::unique_ptr> -CreateIssueIDsOfCustomAggregationOpsPass() { - return std::make_unique(); -} - -} // namespace quant -} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc index f48d15dd81cbdb..18ee96bfe9422e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc @@ -70,14 +70,15 @@ bool IsHashTableOp(Operation* op) { // Checks if the function is the main or initializer function. 
 bool IsMainOrInitializerFunction(ModuleOp module, func::FuncOp func) {
-  if (func.getSymName().equals(tensorflow::kImportModelDefaultGraphFuncName) ||
-      func.getSymName().equals(kTfQuantSaveFuncName)) {
+  if (func.getSymName() ==
+          llvm::StringRef(tensorflow::kImportModelDefaultGraphFuncName) ||
+      func.getSymName() == kTfQuantSaveFuncName) {
     return true;
   }
   for (func::FuncOp init_func :
        tf_saved_model::GetInitializerFunctions(module)) {
-    if (func.getSymName().equals(init_func.getSymName())) {
+    if (func.getSymName() == init_func.getSymName()) {
       return true;
     }
   }
@@ -118,7 +119,7 @@ bool IsResourceInitialized(ModuleOp module_op, Operation* hash_table) {
        tf_saved_model::GetInitializerFunctions(module_op)) {
     for (Operation& op : init_func_op.getBody().getOps()) {
       StringRef other_shared_name = GetSharedName(&op);
-      if (IsHashTableOp(&op) && other_shared_name.equals(shared_name)) {
+      if (IsHashTableOp(&op) && other_shared_name == shared_name) {
         return true;
       }
     }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc
index 63fb3bd94005ee..672cd78b01de9c 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc
@@ -174,7 +174,7 @@ class CheckQuantizableOps
   LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op,
                                 PatternRewriter& rewriter) const override {
     StringRef function_name =
-        call_op.getFAttr().cast<FlatSymbolRefAttr>().getValue();
+        mlir::cast<FlatSymbolRefAttr>(call_op.getFAttr()).getValue();
     if (!function_name.starts_with("composite_") ||
         !call_op->hasAttr(kQuantTraitAttrName)) {
       return failure();
@@ -193,11 +193,10 @@
     }
     // Only the composite functions with f32 inputs are quantizable.
-    if (call_op.getResults().size() == 1 && !call_op->getResult(0)
-                                                 .getType()
-                                                 .cast<ShapedType>()
-                                                 .getElementType()
-                                                 .isF32()) {
+    if (call_op.getResults().size() == 1 &&
+        !mlir::cast<ShapedType>(call_op->getResult(0).getType())
+             .getElementType()
+             .isF32()) {
       check_status.Update(absl::InternalError(
           "Composite functions for quantization should be f32 type."));
     }
@@ -274,7 +273,7 @@ class CheckQuantizableOps
     // For BatchMatMul, the input must be ranked to determine the batch
     // dimensions.
     ShapedType shaped_type =
-        call_op->getOperand(0).getType().dyn_cast<ShapedType>();
+        mlir::dyn_cast<ShapedType>(call_op->getOperand(0).getType());
     if (!shaped_type || !shaped_type.hasRank()) {
       return absl::InternalError("The input of BatchMatMul must have rank.");
     }
@@ -282,7 +281,8 @@ class CheckQuantizableOps
     // This op is guaranteed to be a constant as ODS checks IsConstTensor.
     // Check if the number of elements meets the requirement.
     int64_t num_elements =
-        call_op.getOperand(0).getType().cast<ShapedType>().getNumElements();
+        mlir::cast<ShapedType>(call_op.getOperand(0).getType())
+            .getNumElements();
     if (num_elements < quant_options_.min_num_elements_for_weights()) {
       return absl::InternalError(
           "The params of Gather have fewer number of elements than "
@@ -391,7 +391,9 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() {
   populateWithGenerated(patterns);
   patterns.add<CheckQuantizableOps>(ctx, quant_options_);
   FrozenRewritePatternSet frozen_patterns(std::move(patterns));
-  for (auto func : module.getOps<func::FuncOp>()) {
+
+  // Iterate over the sorted list of functions to keep the order deterministic.
+ for (func::FuncOp func : GetSortedFunctions(module)) { if (failed(applyPatternsAndFoldGreedily(func, frozen_patterns))) { func.emitError() << "quant-lift-quantizable-spots-as-functions failed."; signalPassFailure(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc index 0acb2e56ea617e..a75bef5f842746 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc @@ -137,7 +137,8 @@ class CheckQuantizableOps // This op is guaranteed to be a constant as ODS checks IsConstTensor. // Check if the number of elements meets the requirement. int current_num_elements = - call_op.getOperand(idx).getType().cast().getNumElements(); + mlir::cast(call_op.getOperand(idx).getType()) + .getNumElements(); if (current_num_elements < min_num_elements_for_weights_) { call_op.emitRemark("Quantization is skipped for ") << call_op->getName().getStringRef().str() << " because it has " @@ -149,7 +150,7 @@ class CheckQuantizableOps } StringRef function_name = - call_op.getFAttr().cast().getValue(); + mlir::cast(call_op.getFAttr()).getValue(); if ((quantization_method_ == tensorflow::quantization::QuantizationMethod:: METHOD_DYNAMIC_RANGE_INT8) && (function_name.contains("batch_matmul") || diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc index f1f65a1a183371..fe196b9caa4452 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" @@ -56,7 +57,6 @@ using ::mlir::tf_saved_model::kTfSavedModelInitializerInitType; using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; using ::mlir::tf_saved_model::kTfSavedModelInitializerTypeAttr; using ::mlir::tf_saved_model::SessionInitializerOp; -using ::tensorflow::kImportModelDefaultGraphFuncName; // Array of initializer functions' types. The corresponding initializer // functions should be merged in this order. This is because: @@ -153,7 +153,7 @@ LogicalResult ValidateInitFunc(func::FuncOp init_func_op) { FetchOp fetch_op = graph_op.GetFetch(); for (const Value fetch : fetch_op.getFetches()) { - if (!fetch.getType().isa()) { + if (!mlir::isa(fetch.getType())) { fetch_op.emitError(absl::StrFormat( "Validation failed for the initializer function: %s. 
" "All initializer function's fetches should be " diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc index e092352dc52c29..6f42c9fcaba7c5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc @@ -143,7 +143,7 @@ BlockArgument GetFilePrefixArg(func::FuncOp main_func_op) { auto index_path_attr = main_func_op.getArgAttrOfType(i, kTfSavedModelIndexPathAttr); if (index_path_attr && !index_path_attr.empty() && - index_path_attr[0].cast() == kTfFilePrefix) { + mlir::cast(index_path_attr[0]) == kTfFilePrefix) { return main_func_op.getArgument(i); } } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h index 5ea5a058cc94d3..9a0084ef38f412 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h @@ -63,10 +63,6 @@ CreateLiftQuantizableSpotsAsFunctionsDRQPass( std::unique_ptr> CreateConvertCustomAggregationOpToQuantStatsPass(); -// Issues IDs of custom aggregation ops for preparing the calibration procedure. -std::unique_ptr> -CreateIssueIDsOfCustomAggregationOpsPass(); - // Inserts quantized function library. std::unique_ptr> CreateInsertQuantizedFunctionsPass( tensorflow::quantization::QuantizationMethod::PresetMethod diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc index 38075bb67b7010..a87245345f6987 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc @@ -98,8 +98,8 @@ class PrepareLiftingPass // indices in `val2`. bool HasEqualElementSize(Value val1, Value val2, ArrayRef val1_indices, ArrayRef val2_indices) { - ShapedType val1_shape = val1.getType().cast(); - ShapedType val2_shape = val2.getType().cast(); + ShapedType val1_shape = mlir::cast(val1.getType()); + ShapedType val2_shape = mlir::cast(val2.getType()); if (!val1_shape.hasRank() || !val2_shape.hasRank()) return false; int val1_result = 1; @@ -134,7 +134,7 @@ bool ReshapableTo1DTensor(ShapedType rhs_shape) { } Value ReshapeTo1DTensor(OpBuilder& builder, Location loc, Value value) { - auto shape = value.getType().cast(); + auto shape = mlir::cast(value.getType()); if (shape.getRank() != 1) { SmallVector new_shape; new_shape.push_back(shape.getNumElements()); @@ -157,8 +157,8 @@ LogicalResult MatchSupportedAffineOp(Operation* op, Value& binding_output, bool is_supported_affine_op = false; if (llvm::isa(op)) { if (const auto data_format = op->getAttrOfType("data_format")) { - is_supported_affine_op = data_format.getValue().equals("NHWC") || - data_format.getValue().equals("NDHWC"); + is_supported_affine_op = + data_format.getValue() == "NHWC" || data_format.getValue() == "NDHWC"; } } else if (llvm::isa(op)) { if (const auto adj_y = op->getAttrOfType("adj_y")) { @@ -182,7 +182,7 @@ LogicalResult MatchSupportedAffineOp(Operation* op, Value& binding_output, // Makes the 1D value broadcastable with the `rhs_shape`. 
Value MakeOneDimValueBroadcastable(OpBuilder& builder, Location loc, Value value, ShapedType rhs_shape) { - ShapedType value_shape = value.getType().dyn_cast_or_null(); + ShapedType value_shape = mlir::dyn_cast_or_null(value.getType()); if (!value_shape || value_shape.getRank() != 1 || !value_shape.hasStaticShape() || !rhs_shape.hasStaticShape()) { return {}; @@ -211,7 +211,8 @@ bool CanBeSymmetricallyQuantized(Value weight) { auto dq_op = weight.getDefiningOp(); if (!dq_op) return true; - auto qtype = dq_op.getArg().getType().cast().getElementType(); + auto qtype = + mlir::cast(dq_op.getArg().getType()).getElementType(); if (auto uniform_type = llvm::dyn_cast_or_null(qtype)) { return uniform_type.getZeroPoint() == 0; } else if (auto per_axis_type = @@ -252,12 +253,12 @@ Value MultiplyFakeQuantValue(OpBuilder& builder, Location loc, Value value, Value float_value = q_op.getArg(); Value new_value = builder.create(loc, float_value, multiplier); - auto new_value_type = new_value.getType().cast(); + auto new_value_type = mlir::cast(new_value.getType()); // Get multiplier value in double. DenseFPElementsAttr multiplier_attr; if (!matchPattern(multiplier, m_Constant(&multiplier_attr)) || - multiplier_attr.getType().cast().getRank() > 1) { + mlir::cast(multiplier_attr.getType()).getRank() > 1) { return {}; } std::vector multiplier_values; @@ -268,7 +269,7 @@ Value MultiplyFakeQuantValue(OpBuilder& builder, Location loc, Value value, // Multiply the quantization parameters by the multiplier. QuantizedType new_qtype; - auto element_type = q_op.getType().cast().getElementType(); + auto element_type = mlir::cast(q_op.getType()).getElementType(); if (auto uniform_type = llvm::dyn_cast(element_type)) { if (multiplier_attr.isSplat()) { double new_scale = multiplier_array.front() * uniform_type.getScale(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc index fe38ed8dc0f634..cad8c1686eb67b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -171,8 +172,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { bool need_to_set_input_nodes_quantization_params = false; for (const BlockArgument arg : func.getArguments()) { - auto shaped = arg.getType().dyn_cast(); - if (shaped && shaped.getElementType().isa() && + auto shaped = mlir::dyn_cast(arg.getType()); + if (shaped && mlir::isa(shaped.getElementType()) && !has_quantize_op(arg)) { need_to_set_input_nodes_quantization_params = true; break; @@ -197,8 +198,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { auto add_quantize_op = [&](Location loc, Type input_type, Block* block, Block::iterator insertion_point, Value arg, int i) { - if (auto shaped = input_type.dyn_cast()) { - if (shaped.getElementType().isa()) { + if (auto shaped = mlir::dyn_cast(input_type)) { + if (mlir::isa(shaped.getElementType())) { // If there are existing quantize ops, they are from training and we // should respect them. if (has_quantize_op(arg)) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc index 71587390580406..b2c0ceb205ca99 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" @@ -142,7 +143,7 @@ class PrepareDRQQuantizableOp : public OpRewritePattern { bool getQuantizableOps(arith::ConstantOp op, QuantizationUnits& quantizable_ops) const { // Non-float tensors do not need quantization. 
- auto type = op.getType().dyn_cast(); + auto type = mlir::dyn_cast(op.getType()); if (!type || !type.getElementType().isF32()) return false; Value value = op.getResult(); @@ -183,23 +184,23 @@ class PrepareDRQQuantizableOp : public OpRewritePattern { if (attr.size() < quant_specs_.minimum_elements_for_weights) { op->emitRemark("Quantization is skipped for ") << quantized_op->getName().getStringRef().str() << " because it has " - << attr.dyn_cast().size() + << mlir::dyn_cast(attr).size() << " elements which is fewer than the threshold(" << quant_specs_.minimum_elements_for_weights << " elements)."; return false; } if (is_per_channel_quantization) { - quant_type = quant::GetUniformQuantizedPerAxisTypeForWeight( - attr, quant_dim, - /*symmetric=*/true, bit_width, is_signed, - is_narrow_range, is_legacy_float) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedPerAxisTypeForWeight( + attr, quant_dim, + /*symmetric=*/true, bit_width, is_signed, is_narrow_range, + is_legacy_float)); } else { - quant_type = quant::GetUniformQuantizedTypeForWeight( - attr, is_narrow_range && is_signed, bit_width, is_signed, - is_narrow_range, is_legacy_float) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedTypeForWeight( + attr, is_narrow_range && is_signed, bit_width, is_signed, + is_narrow_range, is_legacy_float)); } return insertQDQ(rewriter, op, quant_type, quant_op); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc index 3f54fe580fe1c4..08b2faadacd3d5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc @@ -202,7 +202,7 @@ class PreprocessConstantOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::PartitionedCallOp op, PatternRewriter& rewriter) const override { - const auto f_attr = op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(op.getFAttr()); // Non-quantizable op if (!op->hasAttr(kQuantTraitAttrName)) return failure(); StringRef function_name = f_attr.getValue(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc index 8570652b4019e7..0d2edd5bacd6c1 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc @@ -100,7 +100,7 @@ class PropagateDequantizeOpIfAllowed LogicalResult matchAndRewrite(TF::PartitionedCallOp op, PatternRewriter& rewriter) const override { - const auto f_attr = op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(op.getFAttr()); StringRef function_name = f_attr.getValue(); if (!function_name.starts_with(kDequantizeFunctionName)) return failure(); @@ -127,7 +127,8 @@ class PropagateDequantizeOpIfAllowed auto original_result_type = user_op->getResult(0).getType(); auto new_user_op_type = CloneTypeWithNewElementType( original_result_type, - op_before_dequantize.getType().cast().getElementType()); + mlir::cast(op_before_dequantize.getType()) + .getElementType()); createNewDequantizeOp(rewriter, op, user_op, user_idx, new_user_op_type); } else { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc 
b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc index 0b3c89c56f60bb..50409709d44854 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc @@ -213,11 +213,11 @@ LogicalResult CreateQuantizationParams(QuantizedType elem_type, Location loc, if (!elem_type) { return failure(); } - if (auto qtype = elem_type.dyn_cast()) { + if (auto qtype = mlir::dyn_cast(elem_type)) { return CreateUniformQuantizedTypeParams(qtype, loc, rewriter, scale, zero_point); - } else if (auto qtype = - elem_type.dyn_cast()) { + } else if (auto qtype = mlir::dyn_cast( + elem_type)) { return CreateUniformQuantizedPerAxisTypeParams(qtype, loc, rewriter, scale, zero_point); } @@ -235,7 +235,7 @@ ShapedType ConvertIntToQint(ShapedType input_type, MLIRContext* ctx) { if (ele_type.isIntOrFloat()) { bit_width = ele_type.getIntOrFloatBitWidth(); is_signed = ele_type.isSignlessIntOrFloat() || ele_type.isSignedInteger(); - } else if (QuantizedType qtype = ele_type.dyn_cast()) { + } else if (QuantizedType qtype = mlir::dyn_cast(ele_type)) { bit_width = qtype.getStorageTypeIntegralWidth(); is_signed = qtype.isSigned(); } else { @@ -275,8 +275,9 @@ class ReplaceQuantizePattern LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op, PatternRewriter& rewriter) const override { - auto output_type = q_op.getType().cast(); - auto elem_type = output_type.getElementType().dyn_cast(); + auto output_type = mlir::cast(q_op.getType()); + auto elem_type = + mlir::dyn_cast(output_type.getElementType()); const Location loc = q_op->getLoc(); Value scale, zero_point; @@ -289,7 +290,7 @@ class ReplaceQuantizePattern if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { ShapedType new_output_type = ConvertIntToQint( - output_type.cast(), rewriter.getContext()); + mlir::cast(output_type), rewriter.getContext()); if (!new_output_type) { q_op->emitError( "Failed to convert the type to the corresponding qtype."); @@ -327,8 +328,8 @@ class ReplaceDequantizePattern LogicalResult matchAndRewrite(quantfork::DequantizeCastOp dq_op, PatternRewriter& rewriter) const override { - auto input_type = dq_op.getArg().getType().cast(); - auto elem_type = input_type.getElementType().dyn_cast(); + auto input_type = mlir::cast(dq_op.getArg().getType()); + auto elem_type = mlir::dyn_cast(input_type.getElementType()); const Location loc = dq_op->getLoc(); Value scale, zero_point; @@ -340,13 +341,13 @@ class ReplaceDequantizePattern TensorType output_type = input_type.clone(elem_type.getStorageType()); if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { ShapedType new_output_type = ConvertIntToQint( - output_type.cast(), rewriter.getContext()); + mlir::cast(output_type), rewriter.getContext()); if (!new_output_type) { dq_op->emitError( "Failed to convert the type to the corresponding qtype."); return failure(); } - output_type = new_output_type.cast(); + output_type = mlir::cast(new_output_type); } auto scast_op = rewriter.create(loc, output_type, @@ -376,8 +377,8 @@ bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { return false; } else if (cur_op) { // Check if the QuantizeCastOp has element type of quantized type. - if (!getElementTypeOrSelf(cur_op.getResult().getType()) - .isa()) { + if (!mlir::isa( + getElementTypeOrSelf(cur_op.getResult().getType()))) { return false; } // Satisfies the input condition. 
@@ -385,8 +386,8 @@ bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { } } for (Value output : call_op.getOutput()) { - if (auto type = output.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = mlir::dyn_cast(output.getType())) { + if (mlir::isa(type.getElementType())) { return false; } } @@ -398,15 +399,15 @@ bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { bool IsQuantizedCallforStaticRange(TF::PartitionedCallOp call_op) { bool has_quantized_types = false; for (Value input : call_op.getArgs()) { - if (auto type = input.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = mlir::dyn_cast(input.getType())) { + if (mlir::isa(type.getElementType())) { has_quantized_types = true; } } } for (Value output : call_op.getOutput()) { - if (auto type = output.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = mlir::dyn_cast(output.getType())) { + if (mlir::isa(type.getElementType())) { has_quantized_types = true; } } @@ -616,7 +617,7 @@ std::string GetQuantizedFunctionName(StringRef func_name, bool ContainsFloatResultType(ArrayRef result_types) { for (auto current_type : result_types) { - if (current_type.dyn_cast().getElementType().isF32()) + if (mlir::dyn_cast(current_type).getElementType().isF32()) return true; } return false; @@ -644,7 +645,7 @@ class QuantizeFunctionPattern LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, PatternRewriter& rewriter) const override { - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); // removeAttr will return nullptr if no attribute was removed. if (!call_op->removeAttr(kQuantTraitAttrName) || !f_attr) { return failure(); @@ -671,12 +672,12 @@ class QuantizeFunctionPattern SmallVector args; SmallVector qparam_args; for (Value arg : call_op.getArgs()) { - if (const auto arg_type = arg.getType().dyn_cast()) { + if (const auto arg_type = mlir::dyn_cast(arg.getType())) { QuantizedType qtype = - arg_type.getElementType().dyn_cast(); + mlir::dyn_cast(arg_type.getElementType()); if (!qtype) continue; - if (!qtype.isa()) { + if (!mlir::isa(qtype)) { return failure(); } Value scale, zero_point; @@ -693,12 +694,12 @@ class QuantizeFunctionPattern } for (Value result : call_op->getResults()) { - if (auto result_type = result.getType().dyn_cast()) { + if (auto result_type = mlir::dyn_cast(result.getType())) { QuantizedType qtype = - result_type.getElementType().dyn_cast(); + mlir::dyn_cast(result_type.getElementType()); if (!qtype) continue; - if (!qtype.isa()) { + if (!mlir::isa(qtype)) { return failure(); } Value scale, zero_point; @@ -717,12 +718,13 @@ class QuantizeFunctionPattern rewriter.setInsertionPoint(call_op); for (Value arg : call_op.getArgs()) { - TensorType arg_type = arg.getType().dyn_cast(); + TensorType arg_type = mlir::dyn_cast(arg.getType()); if (!arg_type) { args.push_back(arg); continue; } - QuantizedType qtype = arg_type.getElementType().dyn_cast(); + QuantizedType qtype = + mlir::dyn_cast(arg_type.getElementType()); if (!qtype) { args.push_back(arg); continue; @@ -730,15 +732,15 @@ class QuantizeFunctionPattern quantfork::StorageCastOp scast_op; if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { - ShapedType new_arg_type = ConvertIntToQint(arg_type.cast(), - rewriter.getContext()); + ShapedType new_arg_type = ConvertIntToQint( + mlir::cast(arg_type), rewriter.getContext()); if (!new_arg_type) { call_op->emitError( "Failed to convert the type to the corresponding 
qtype."); return failure(); } scast_op = rewriter.create( - arg.getLoc(), new_arg_type.cast(), arg); + arg.getLoc(), mlir::cast(new_arg_type), arg); } else { scast_op = rewriter.create( arg.getLoc(), arg_type.clone(qtype.getStorageType()), arg); @@ -761,20 +763,20 @@ class QuantizeFunctionPattern SmallVector result_types; for (Value result : call_op->getResults()) { - TensorType result_type = result.getType().dyn_cast(); + TensorType result_type = mlir::dyn_cast(result.getType()); if (!result_type) { result_types.push_back(result.getType()); continue; } QuantizedType qtype = - result_type.getElementType().dyn_cast(); + mlir::dyn_cast(result_type.getElementType()); if (!qtype) { result_types.push_back(result_type); continue; } if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { ShapedType new_result_type = ConvertIntToQint( - result_type.cast(), rewriter.getContext()); + mlir::cast(result_type), rewriter.getContext()); result_types.push_back(new_result_type); } else { result_types.push_back(result_type.clone(qtype.getStorageType())); @@ -871,13 +873,13 @@ class QuantizeFunctionPattern rewriter.setInsertionPointAfter(call_op); SmallVector result_types; for (Value result : call_op->getResults()) { - TensorType result_type = result.getType().dyn_cast(); + TensorType result_type = mlir::dyn_cast(result.getType()); if (!result_type) { result_types.push_back(result.getType()); continue; } QuantizedType qtype = - result_type.getElementType().dyn_cast(); + mlir::dyn_cast(result_type.getElementType()); if (!qtype) { result_types.push_back(result_type); continue; @@ -890,7 +892,7 @@ class QuantizeFunctionPattern auto module = call_op->getParentOfType(); SymbolTable symbol_table(module); - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); const auto float_func = dyn_cast(symbol_table.lookup(f_attr.getValue())); rewriter.setInsertionPointAfter(float_func); @@ -973,14 +975,15 @@ class QuantizeConstPattern return failure(); } - ShapedType tensor_qtype = q_op.getResult().getType().cast(); + ShapedType tensor_qtype = + mlir::cast(q_op.getResult().getType()); Attribute tensor_proto_attr = Quantize(attr, tensor_qtype); if (!tensor_proto_attr) { return failure(); } - Type storage_type = - tensor_qtype.getElementType().cast().getStorageType(); + Type storage_type = mlir::cast(tensor_qtype.getElementType()) + .getStorageType(); ShapedType new_type = tensor_qtype.clone(storage_type); Location loc = q_op.getArg().getLoc(); @@ -991,14 +994,14 @@ class QuantizeConstPattern // workaround. tensorflow::TensorProto tensor_proto; if (!mlir::tfg::ConvertToTensorProto( - tensor_proto_attr.cast(), &tensor_proto) + mlir::cast(tensor_proto_attr), &tensor_proto) .ok()) { return failure(); } - const int bit_width = tensor_qtype.getElementType() - .dyn_cast() - .getStorageTypeIntegralWidth(); + const int bit_width = + mlir::dyn_cast(tensor_qtype.getElementType()) + .getStorageTypeIntegralWidth(); tensor_proto.set_dtype((bit_width == 8) ? 
tensorflow::DT_QINT8 : tensorflow::DT_QINT32); @@ -1033,8 +1036,9 @@ class RestoreWeightShapePattern int weight_operand_idx = 1; Operation* weight_op = op.getOperand(weight_operand_idx).getDefiningOp(); - auto weight_type = weight_op->getResult(0).getType().dyn_cast(); - auto input_type = op.getOperand(0).getType().dyn_cast(); + auto weight_type = + mlir::dyn_cast(weight_op->getResult(0).getType()); + auto input_type = mlir::dyn_cast(op.getOperand(0).getType()); llvm::ArrayRef weight_shape = weight_type.getShape(); llvm::ArrayRef input_shape = input_type.getShape(); @@ -1073,7 +1077,7 @@ class RestoreWeightShapePattern LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, PatternRewriter& rewriter) const override { - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); StringRef function_name = f_attr.getValue(); // TODO(b/228928859): Improve the getter function to match attributes rather // than function name. @@ -1106,7 +1110,8 @@ class QuantizationSummary { module_.walk([&](Operation* op) { if (auto call_op = llvm::dyn_cast_or_null(op)) { - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = + mlir::dyn_cast(call_op.getFAttr()); if (!f_attr) return; StringRef func_name = f_attr.getValue(); if (func_name.starts_with(kQuantizedFuncPrefix)) { @@ -1227,7 +1232,7 @@ class QuantizationSummary { } // Use the first op as the representative name. - return quantized_ops.front().cast().getValue(); + return mlir::cast(quantized_ops.front()).getValue(); } bool IsInCompsiteFunction(Operation* op) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc index 374d687428ee3e..b202798dffe9d0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -77,8 +78,9 @@ void PrepareXlaConvParams(OpBuilder &builder, Location loc, ArrayAttr strides, SmallVector lhs_dilation_values(num_dims - 2, 1); SmallVector stride_values, rhs_dilation_values; for (int64_t i : llvm::seq(1, num_dims - 1)) { - stride_values.push_back(strides[i].cast().getInt()); - rhs_dilation_values.push_back(dilations[i].cast().getInt()); + stride_values.push_back(mlir::cast(strides[i]).getInt()); + rhs_dilation_values.push_back( + mlir::cast(dilations[i]).getInt()); } window_strides = Create1DConstValue(builder, loc, stride_values); lhs_dilation = Create1DConstValue(builder, loc, lhs_dilation_values); @@ -96,7 +98,7 @@ Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, return CreateScalarConstValue(builder, loc, 0); } - auto shape = tensor.getType().template cast(); + auto shape = mlir::cast(tensor.getType()); SmallVector non_output_indices; for (int64_t i : llvm::seq(0, shape.getRank())) { if (absl::c_count(output_dims, i) == 0) { @@ -108,7 +110,7 @@ Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, Create1DConstValue(builder, loc, non_output_indices); auto zp = CreateScalarConstValue(builder, loc, other_tensor_zp); - TensorType tensor_type = tensor.getType().dyn_cast(); + TensorType tensor_type = mlir::dyn_cast(tensor.getType()); Value tensor_i32 = builder.create( loc, tensor_type.clone(builder.getIntegerType(32)), tensor); auto reduced = @@ -136,7 +138,7 @@ Value MergeZeroPointOffset(OpBuilder &builder, Location loc, Value weight, int8_t input_zp, int8_t weight_zp, Value zp_input_contribution, Value zp_weight_contribution) { - auto weight_shape = weight.getType().template cast(); + auto weight_shape = mlir::cast(weight.getType()); SmallVector weight_non_output_indices; for (auto i : llvm::seq(0, weight_shape.getRank())) { if (absl::c_count(weight_output_dims, i) == 0) { @@ -498,7 +500,7 @@ Value CreateZeroPointPartialOffsetXlaDotV2( return CreateScalarConstValue(builder, loc, 0); } - auto shape = tensor.getType().template cast(); + auto shape = mlir::cast(tensor.getType()); SmallVector tensor_shape; for (auto v : shape.getShape()) { tensor_shape.push_back(v); @@ -506,7 +508,7 @@ Value CreateZeroPointPartialOffsetXlaDotV2( auto zp = CreateScalarConstValue(builder, loc, other_tensor_zp); - TensorType tensor_type = tensor.getType().dyn_cast(); + TensorType tensor_type = mlir::dyn_cast(tensor.getType()); Value tensor_i32 = builder.create( loc, tensor_type.clone(builder.getIntegerType(32)), tensor); @@ -596,7 +598,7 @@ Value CalculateZeroPointOffsetXLADotV2(OpBuilder &builder, Location loc, Value zp_weight_contribution = CreateZeroPointPartialOffsetXlaDotV2( builder, loc, weight, input_zp, dnums, /*is_lhs=*/false, output_rank); - auto weight_shape = weight.getType().template cast(); + auto weight_shape = mlir::cast(weight.getType()); absl::flat_hash_set rhs_contracting_dims; for (auto dim : dnums.rhs_contracting_dimensions()) { @@ -711,8 +713,8 @@ Value CreateXlaConvOpFromTfConv2dOp(OpBuilder &builder, Location loc, ArrayAttr dilations, StringAttr conv_padding, ArrayAttr explicit_paddings) { - auto input_shape = 
input.getType().template cast(); - auto filter_shape = filter.getType().template cast(); + auto input_shape = mlir::cast(input.getType()); + auto filter_shape = mlir::cast(filter.getType()); if (!input_shape.hasRank() || input_shape.getRank() != 4 || !filter_shape.hasRank() || filter_shape.getRank() != 4) { emitError(loc, "input and filter are expected to be 4D tensors"); @@ -731,8 +733,8 @@ Value CreateXlaConvOpFromTfDepthwiseConv2dOp( OpBuilder &builder, Location loc, Value input, Value filter, Value input_zp, Value conv_output, ArrayAttr strides, ArrayAttr dilations, StringAttr conv_padding, ArrayAttr explicit_paddings) { - auto input_shape = input.getType().template cast(); - auto filter_shape = filter.getType().template cast(); + auto input_shape = mlir::cast(input.getType()); + auto filter_shape = mlir::cast(filter.getType()); if (!input_shape.hasRank() || input_shape.getRank() != 4 || !filter_shape.hasRank() || filter_shape.getRank() != 4) { emitError(loc, "input and filter are expected to be 4D tensors"); @@ -759,8 +761,8 @@ Value CreateXlaConvOpFromTfConv3dOp(OpBuilder &builder, Location loc, Value conv_output, ArrayAttr strides, ArrayAttr dilations, StringAttr conv_padding) { - auto input_shape = input.getType().template cast(); - auto filter_shape = filter.getType().template cast(); + auto input_shape = mlir::cast(input.getType()); + auto filter_shape = mlir::cast(filter.getType()); if (!input_shape.hasRank() || input_shape.getRank() != 5 || !filter_shape.hasRank() || filter_shape.getRank() != 5) { emitError(loc, "input and filter are expected to be 5D tensors"); @@ -819,7 +821,7 @@ Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, Value zp_offset = CalculateZeroPointOffsetXLADotV2( builder, loc, input, weight, input_zp_value, weight_zp_value, dnums, - output.getType().template cast().getRank()); + mlir::cast(output.getType()).getRank()); return builder.create(loc, dot_result, zp_offset); } @@ -891,8 +893,8 @@ GetBroadcastShapesForBatchMatmul(ShapedType input_type, // function, except BroadcastTo, are expected to be folded. void BroadcastBatchDimensionsForBatchMatMul(OpBuilder &builder, Location loc, Value &input, Value &weight) { - ShapedType input_type = input.getType().template cast(); - ShapedType weight_type = weight.getType().template cast(); + ShapedType input_type = mlir::cast(input.getType()); + ShapedType weight_type = mlir::cast(weight.getType()); const int32_t input_rank = input_type.getRank(); const int32_t weight_rank = weight_type.getRank(); const int32_t broadcasted_rank = std::max(input_rank, weight_rank); @@ -984,7 +986,7 @@ Value CreateXlaDotV2OpFromTfBatchMatMulOp(OpBuilder &builder, Location loc, BroadcastBatchDimensionsForBatchMatMul(builder, loc, input, weight); // Both input and weight have the same rank after broadcasting. - ShapedType weight_shape = weight.getType().template cast(); + ShapedType weight_shape = mlir::cast(weight.getType()); int num_batch_dim = weight_shape.getRank() - 2; // Transpose and constant-fold the weight if needed. @@ -1016,7 +1018,7 @@ Value CreateXlaDotV2OpFromTfBatchMatMulOp(OpBuilder &builder, Location loc, // Check if the given value is a ranked type with specified integer width. 
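Background for the zero-point plumbing in CreateZeroPointPartialOffset, MergeZeroPointOffset, and CalculateZeroPointOffsetXLADotV2 above: for a contraction of length \(K\) between quantized operands with zero points \(z_x\) and \(z_w\), the integer arithmetic decomposes as

\[
\sum_{k}\bigl(q_x[k]-z_x\bigr)\bigl(q_w[k]-z_w\bigr)
= \sum_{k} q_x[k]\,q_w[k]
\;-\; z_w\sum_{k} q_x[k]
\;-\; z_x\sum_{k} q_w[k]
\;+\; K\,z_x z_w,
\]

where the first term is the raw XlaDotV2 (or XlaConvV2) accumulation. Each partial-offset helper computes one of the two weighted sums by reducing over the non-output (contracting) dimensions, the merge step folds in the \(K\,z_x z_w\) correction, and the combined offset is then subtracted from the raw dot product.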
bool IsRankedInt(Value value, const int integer_width) { - ShapedType value_type = value.getType().template cast(); + ShapedType value_type = mlir::cast(value.getType()); if (!value_type.hasRank()) return false; if (!value_type.getElementType().isInteger(integer_width)) return false; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td index e33e226be35515..4928bafc7490eb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td @@ -61,7 +61,18 @@ def TF_CustomAggregatorOp : TF_Op<"CustomAggregator", [Pure]> { let arguments = (ins TensorOf<[TF_Float32]>:$input, - StrAttr:$id + // The unique id of this `CustomAggregator` op. + StrAttr:$id, + // The integer value of the enforcing `CalibrationMethod`. + I32Attr:$calibration_method, + // The number of histogram bins. + I32Attr:$num_bins, + // Min percentile to be included in the selected range, only used in the + // `HISTOGRAM_PERCENTILE` method. + F32Attr:$min_percentile, + // Max percentile to be included in the selected range, only used in the + // `HISTOGRAM_PERCENTILE` method. + F32Attr:$max_percentile ); let results = (outs @@ -72,6 +83,20 @@ def TF_CustomAggregatorOp : TF_Op<"CustomAggregator", [Pure]> { ); } +def TF_CalibrationStatisticsSaverOp : TF_Op<"CalibrationStatisticsSaver", []> { + let summary = "Aggregates and saves calibration statistics."; + + let arguments = (ins + Variadic>:$inputs, + + StrAttr:$output_file_path, + StrArrayAttr:$ids, + I32ArrayAttr:$calibration_methods + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>; +} + def TF_DumpTensorOp : TF_Op<"DumpTensor", []> { let summary = "Dump tensor proto."; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 78a8321f9f87d4..c0a472ca8f2e26 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -45,6 +45,8 @@ cc_library( "//tensorflow/compiler/mlir/quantization/stablehlo/cc:saved_model_export", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:saved_model_import", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:types", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:weight_only_ptq", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:component", "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:statistics", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index 08ff75ac802613..b44c788bc10f7b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -346,25 +346,10 @@ def test_drq_per_channel_for_non_uniform_opset_raises_value_error( self._input_saved_model_path, quantization_options=options ) - @test_util.run_in_graph_and_eager_modes def test_force_graph_mode_calibration(self): - input_type = dtypes.int32 - input_placeholder = self._create_and_save_tf1_gather_model( 
-        self._input_saved_model_path,
-        signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
-        tags={tag_constants.SERVING},
-        input_key='x',
-        output_key='output',
-        input_type=input_type,
-    )
+    model = self.SimpleModel()
 
-    data_gen = self._create_data_generator(
-        input_key='x',
-        shape=input_placeholder.shape,
-        minval=0,
-        maxval=10,
-        dtype=input_type,
-    )
+    saved_model_save.save(model, self._input_saved_model_path)
 
     options = quant_opts_pb2.QuantizationOptions(
         quantization_method=quant_opts_pb2.QuantizationMethod(
@@ -383,7 +368,7 @@ def test_force_graph_mode_calibration(self):
         quantize_model.quantize(
             self._input_saved_model_path,
             quantization_options=options,
-            representative_dataset=data_gen,
+            representative_dataset=self._simple_model_data_gen(),
         )
       finally:
         # Restore the logger verbosity.
@@ -2725,6 +2710,116 @@ def data_gen() -> repr_dataset.RepresentativeDataset:
     self.assertAllClose(new_outputs, got_outputs, atol=0.097)
     self.assertAllClose(new_outputs, expected_outputs, atol=0.057)
 
+  def test_reuse_calibration_data(self):
+    model = self._create_simple_gather_and_conv_model(
+        dtypes.int32, filter_shape=(2, 3, 3, 1024)
+    )
+    saved_model_save.save(model, self._input_saved_model_path)
+
+    data_gen = self._create_data_generator(
+        input_key='input_tensor',
+        shape=[50],
+        minval=0,
+        maxval=64,
+        dtype=dtypes.int32,
+    )
+
+    tags = {tag_constants.SERVING}
+
+    calibration_data_dir = self.create_tempdir('calibration_data').full_path
+    quantization_options = quant_opts_pb2.QuantizationOptions(
+        quantization_method=quant_opts_pb2.QuantizationMethod(
+            preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8
+        ),
+        tags=tags,
+        signature_keys=['serving_default'],
+        op_set=quant_opts_pb2.XLA,
+        force_graph_mode_calibration=True,
+        calibration_options=stablehlo_quant_config_pb2.CalibrationOptions(
+            calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX,
+            calibration_data_dir=calibration_data_dir,
+        ),
+    )
+
+    # Run quantization the first time, calibration is expected to be run.
+    with self.assertLogs(level='INFO') as info_logs:
+      # Save the logger verbosity.
+      prev_log_level = logging.get_verbosity()
+      logging.set_verbosity(logging.INFO)
+      try:
+        converted_model1 = quantize_model.quantize(
+            self._input_saved_model_path,
+            self._output_saved_model_path,
+            quantization_options,
+            representative_dataset=data_gen,
+        )
+      finally:
+        # Restore the logger verbosity.
+        logging.set_verbosity(prev_log_level)
+
+    self.assertNotEmpty(info_logs.records)
+    self.assertTrue(
+        self._any_log_contains(
+            'Calibration step is executed in graph mode.',
+            info_logs.records,
+        )
+    )
+    self.assertIsNotNone(converted_model1)
+    self.assertCountEqual(
+        converted_model1.signatures._signatures.keys(), {'serving_default'}
+    )
+
+    output_loader = saved_model_loader.SavedModelLoader(
+        self._output_saved_model_path
+    )
+    output_graphdef = output_loader.get_meta_graph_def_from_tags(
+        tags
+    ).graph_def
+    self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2'))
+
+    # Run quantization the second time, calibration is expected to be skipped.
+    with self.assertLogs(level='INFO') as info_logs:
+      # Save the logger verbosity.
+      prev_log_level = logging.get_verbosity()
+      logging.set_verbosity(logging.INFO)
+      try:
+        converted_model2 = quantize_model.quantize(
+            self._input_saved_model_path,
+            self._output_saved_model_path,
+            quantization_options,
+            representative_dataset=data_gen,
+            overwrite_output_directory=True,
+        )
+      finally:
+        # Restore the logger verbosity.
+ logging.set_verbosity(prev_log_level) + + self.assertNotEmpty(info_logs.records) + self.assertFalse( + self._any_log_contains( + 'Calibration step is executed in graph mode.', + info_logs.records, + ) + ) + self.assertIsNotNone(converted_model2) + self.assertCountEqual( + converted_model2.signatures._signatures.keys(), {'serving_default'} + ) + + # Expect two models to produce the same results. + test_data = ops.convert_to_tensor( + np.random.uniform(low=0, high=64, size=(32)).astype( + dtypes.int32.as_numpy_dtype + ) + ) + new_outputs_1 = converted_model1.signatures['serving_default']( + input_tensor=test_data + )['output'] + new_outputs_2 = converted_model2.signatures['serving_default']( + input_tensor=test_data + )['output'] + self.assertAllClose(new_outputs_1, new_outputs_2) + @test_util.run_in_graph_and_eager_modes def test_function_alias_preserved(self): model = self._create_conv2d_model( @@ -5406,6 +5501,7 @@ def test_einsum_model( @parameterized.named_parameters( ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ('stablehlo_per_channel', quant_opts_pb2.STABLEHLO, True), ) @test_util.run_in_graph_and_eager_modes def test_matmul_model( @@ -5447,8 +5543,14 @@ def test_matmul_model( ) output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + if target_opset == quant_opts_pb2.XLA: + self.assertTrue(self._contains_op(output_graphdef, 'XlaDotV2')) + elif target_opset == quant_opts_pb2.STABLEHLO: + # This is to verify the invocation of StableHLO quantizer works. More + # thorough functional tests are in StableHLO quantizer directory. + self.assertTrue(self._contains_op(output_graphdef, 'XlaCallModule')) + # Due to other meta data, the compression is not exactly 1/4. - self.assertTrue(self._contains_op(output_graphdef, 'XlaDotV2')) self.assertLess( testing.get_size_ratio( self._output_saved_model_path, self._input_saved_model_path @@ -5458,6 +5560,7 @@ def test_matmul_model( @parameterized.named_parameters( ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ('stablehlo_per_channel', quant_opts_pb2.STABLEHLO, True), # TODO: b/289761265 - [Converter Component][TF-Quantizer] Improve Weight- # only Quantization # Enable this back once new weight-only quantizer is supported for per- @@ -5517,7 +5620,7 @@ def test_conv_model( 0.3, ) - if enable_per_channel_quantization: + if enable_per_channel_quantization and target_opset == quant_opts_pb2.XLA: per_channel_size_attr = attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue( shape=[ @@ -5536,6 +5639,12 @@ def test_conv_model( output_graphdef, 'Const', '_output_shapes', per_channel_size_attr ) ) + if target_opset == quant_opts_pb2.XLA: + self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) + elif target_opset == quant_opts_pb2.STABLEHLO: + # This is to verify the invocation of StableHLO quantizer works. More + # thorough functional tests are in StableHLO quantizer directory. 
+ self.assertTrue(self._contains_op(output_graphdef, 'XlaCallModule')) input_tensor = array_ops.constant( np.random.uniform(low=0, high=0.1, size=input_shape), @@ -6211,25 +6320,25 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=32, + num_bins=32, ), ), stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=32, + num_bins=32, ), ), stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=32, + num_bins=32, ), ), stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=32, + num_bins=32, ), ), ], @@ -6376,7 +6485,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 'default_calibration_options': stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=256, + num_bins=512, min_percentile=0.001, max_percentile=99.999, ), @@ -6390,7 +6499,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 'default_calibration_options': stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=256 + num_bins=512 ), ), }, @@ -6402,7 +6511,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 'default_calibration_options': stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=256 + num_bins=512 ), ), }, @@ -6414,7 +6523,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 'default_calibration_options': stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=256 + num_bins=512 ), ), }, @@ -6441,8 +6550,8 @@ def test_default_calibration_options( default_calibration_options.calibration_method, ) self.assertEqual( - quant_opts.calibration_options.calibration_parameters.initial_num_bins, - default_calibration_options.calibration_parameters.initial_num_bins, + quant_opts.calibration_options.calibration_parameters.num_bins, + default_calibration_options.calibration_parameters.num_bins, ) self.assertEqual( quant_opts.calibration_options.calibration_parameters.min_percentile, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 
9f4621360e2e89..e38310879184ef 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" @@ -46,6 +47,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" @@ -77,13 +79,14 @@ using ::mlir::quant::stablehlo::GetFunctionAliases; using ::mlir::quant::stablehlo::kExportStepSuffix; using ::mlir::quant::stablehlo::PostCalibrationComponent; using ::mlir::quant::stablehlo::PreCalibrationComponent; +using ::mlir::quant::stablehlo::RunCalibrationPasses; using ::mlir::quant::stablehlo::UpdateFunctionAliases; +using ::mlir::quant::stablehlo::WeightOnlyPtqComponent; using ::stablehlo::quantization::AddCalibrationStatistics; using ::stablehlo::quantization::ChangeToQuantizedFilename; using ::stablehlo::quantization::DebuggerConfig; -using ::stablehlo::quantization::DisableDebugging; -using ::stablehlo::quantization::EnableDebugging; using ::stablehlo::quantization::ExpandPresets; +using ::stablehlo::quantization::IsCalibrationRequired; using ::stablehlo::quantization::PopulateDefaults; using ::stablehlo::quantization::QuantizationConfig; using ::stablehlo::quantization::io::CreateTmpDir; @@ -156,13 +159,17 @@ absl::StatusOr ModuleOpToExportedModel( absl::StatusOr ExportCalibrationModel( mlir::ModuleOp module_op, mlir::MLIRContext *context, const QuantizationOptions &quantization_options, - const absl::flat_hash_map &function_aliases) { + const absl::flat_hash_map &function_aliases, + absl::string_view calibration_data_dir) { // Clone ModuleOp and function aliases so changes in this pipeline won't // be reflected in the original values. mlir::OwningOpRef cloned_module_ref(module_op.clone()); - // Disable DumpTensor ops when running calibration. 
- DisableDebugging(*cloned_module_ref); + TF_RETURN_IF_ERROR( + RunCalibrationPasses(*cloned_module_ref, *context, calibration_data_dir, + quantization_options.calibration_options() + .force_regenerate_calibration_data())); + if (!IsCalibrationRequired(*cloned_module_ref)) return ExportedModel(); absl::StatusOr exported_model = ModuleOpToExportedModel( *cloned_module_ref, context, kTfQuantPtqPreCalibrationStepName, @@ -177,6 +184,27 @@ absl::StatusOr ExportCalibrationModel( return *exported_model; } +absl::StatusOr ExportDebuggingModel( + mlir::ModuleOp module_op, mlir::MLIRContext *context, + const QuantizationOptions &quantization_options, + const absl::flat_hash_map &function_aliases) { + // Clone ModuleOp and function aliases so changes in this pipeline won't + // be reflected in the original values. + mlir::OwningOpRef cloned_module_ref(module_op.clone()); + + absl::StatusOr exported_model = ModuleOpToExportedModel( + *cloned_module_ref, context, kTfQuantPtqPreCalibrationStepName, + /*unfreeze_constants=*/!quantization_options.freeze_all_variables(), + function_aliases); + if (!exported_model.status().ok()) { + return absl::InternalError( + absl::StrCat("Failed to export debugging model: ", + exported_model.status().message())); + } + + return *exported_model; +} + QuantizationConfig GetQuantizationConfigForStaticRangePtq( const QuantizationOptions &quantization_options) { QuantizationConfig quantization_config{}; @@ -197,10 +225,25 @@ QuantizationConfig GetQuantizationConfigForStaticRangePtq( return ExpandPresets(PopulateDefaults(quantization_config)); } +QuantizationConfig GetQuantizationConfigForWeightOnlyPtq( + const QuantizationOptions &quantization_options) { + QuantizationConfig quantization_config{}; + quantization_config.mutable_weight_only_ptq_preset(); + // When targeting server TPUs quantized types should be unpacked into + // integer ops. + quantization_config.mutable_pipeline_config()->set_unpack_quantized_types( + true); + *quantization_config.mutable_debugger_config() = + quantization_options.debugger_config(); + + return ExpandPresets(PopulateDefaults(quantization_config)); +} + absl::StatusOr QuantizePtqModelPreCalibrationImpl( mlir::ModuleOp module_op, mlir::MLIRContext *context, const QuantizationOptions &quantization_options, - const absl::flat_hash_map &function_aliases) { + const absl::flat_hash_map &function_aliases, + absl::string_view calibration_data_dir) { const bool is_stablehlo = quantization_options.op_set() == OpSet::STABLEHLO; // Use StableHLO Quantizer option if opset is specified. 
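Two helpers introduced by this change drive the skip decision above: RunCalibrationPasses points CustomAggregator statistics at files under calibration_data_dir, and IsCalibrationRequired reports whether any statistics are still missing. A toy stand-in for the resulting rule, assuming statistics live as plain files in that directory (ShouldRunCalibration and the file-based check are hypothetical, not patch APIs):

    #include <filesystem>
    #include <string>

    // Hypothetical sketch: calibration re-runs only when forced, or when no
    // statistics from a previous run exist under `calibration_data_dir`.
    bool ShouldRunCalibration(const std::string& calibration_data_dir,
                              bool force_regenerate_calibration_data) {
      namespace fs = std::filesystem;
      if (force_regenerate_calibration_data) return true;
      return !fs::exists(calibration_data_dir) ||
             fs::is_empty(calibration_data_dir);
    }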
if (is_stablehlo) { @@ -221,7 +264,7 @@ absl::StatusOr QuantizePtqModelPreCalibrationImpl( } return ExportCalibrationModel(module_op, context, quantization_options, - function_aliases); + function_aliases, calibration_data_dir); } absl::StatusOr QuantizePtqModelPostCalibrationImpl( @@ -358,6 +401,7 @@ absl::StatusOr QuantizeWeightOnly( "Failed to get function alias: ", function_aliases.status().message())); } + const bool is_stablehlo = quantization_options.op_set() == OpSet::STABLEHLO; absl::StatusOr> module = ImportAndPreprocessSavedModel( saved_model_path, @@ -365,7 +409,8 @@ absl::StatusOr QuantizeWeightOnly( quantization_options.signature_keys().end()}, {quantization_options.tags().begin(), quantization_options.tags().end()}, - context.get(), /*is_inliner_run=*/true, /*run_tf_to_stablehlo=*/false, + context.get(), /*is_inliner_run=*/true, + /*run_tf_to_stablehlo=*/is_stablehlo, /*deserialize_xla_call_module=*/false, *function_aliases); if (!module.status().ok()) { return absl::InternalError( @@ -374,14 +419,24 @@ absl::StatusOr QuantizeWeightOnly( } mlir::OwningOpRef module_ref = std::move(module).value(); - TF_RETURN_IF_ERROR(RunPasses( - kTfQuantWeightOnlyStepName, - /*add_passes_func=*/ - [&quantization_options](mlir::PassManager &pm) { - AddQuantizeWeightOnlyPasses(pm, quantization_options, - kTfQuantWeightOnlyStepName); - }, - *context, *module_ref)); + // Use StableHLO Quantizer option if opset is specified. + if (is_stablehlo) { + const QuantizationConfig quantization_config = + GetQuantizationConfigForWeightOnlyPtq(quantization_options); + + WeightOnlyPtqComponent weight_only_ptq_component(context.get()); + TF_ASSIGN_OR_RETURN(*module_ref, weight_only_ptq_component.Run( + *module_ref, quantization_config)); + } else { + TF_RETURN_IF_ERROR(RunPasses( + kTfQuantWeightOnlyStepName, + /*add_passes_func=*/ + [&quantization_options](mlir::PassManager &pm) { + AddQuantizeWeightOnlyPasses(pm, quantization_options, + kTfQuantWeightOnlyStepName); + }, + *context, *module_ref)); + } return ModuleOpToExportedModel( *module_ref, context.get(), kTfQuantWeightOnlyStepName, @@ -422,27 +477,34 @@ absl::StatusOr QuantizeStaticRangePtq( } mlir::OwningOpRef module_ref = std::move(module).value(); - TF_ASSIGN_OR_RETURN( - absl::StatusOr pre_calibration_exported_model, - QuantizePtqModelPreCalibrationImpl( - *module_ref, context.get(), quantization_options, *function_aliases)); + std::string calibration_data_dir = + quantization_options.calibration_options().calibration_data_dir(); + if (calibration_data_dir.empty()) { + TF_ASSIGN_OR_RETURN(calibration_data_dir, CreateTmpDir()); + } - TF_ASSIGN_OR_RETURN( - const absl::StatusOr precalibrated_saved_model_dir, - CreateTmpDir()); + TF_ASSIGN_OR_RETURN(ExportedModel calibration_exported_model, + QuantizePtqModelPreCalibrationImpl( + *module_ref, context.get(), quantization_options, + *function_aliases, calibration_data_dir)); - py_function_library.SaveExportedModel( - *precalibrated_saved_model_dir, *pre_calibration_exported_model, - saved_model_path, tags, signature_def_map); + // Save and run the calibration model. 
+ if (calibration_exported_model.has_graph_def()) { + TF_ASSIGN_OR_RETURN(std::string calibration_saved_model_dir, + CreateTmpDir()); + py_function_library.SaveExportedModel( + calibration_saved_model_dir, calibration_exported_model, + saved_model_path, tags, signature_def_map); - py_function_library.RunCalibration( - *precalibrated_saved_model_dir, signature_keys, tags, - quantization_options.force_graph_mode_calibration(), - representative_dataset_file_map_serialized); + py_function_library.RunCalibration( + calibration_saved_model_dir, signature_keys, tags, + quantization_options.force_graph_mode_calibration(), + representative_dataset_file_map_serialized); + } if (absl::Status status = AddCalibrationStatistics( - *module_ref, quantization_options.calibration_options(), - py_function_library); + *module_ref, calibration_data_dir, + quantization_options.calibration_options(), py_function_library); !status.ok()) { LOG(WARNING) << "Some CustomAggregator ops do not have min or max " "values. Parts of the graph are not quantized. " @@ -459,14 +521,17 @@ absl::StatusOr QuantizeStaticRangePtq( if (quantization_options.has_debugger_config() && quantization_options.debugger_config().debugger_type() == DebuggerConfig::DEBUGGER_TYPE_WHOLE_MODEL) { - EnableDebugging(*pre_calibration_exported_model); + TF_ASSIGN_OR_RETURN( + ExportedModel debugging_exported_model, + ExportDebuggingModel(*module_ref, context.get(), quantization_options, + *function_aliases)); ChangeToQuantizedFilename(*module_ref); absl::string_view unquantized_dump_model_path = quantization_options.debugger_config().unquantized_dump_model_path(); py_function_library.SaveExportedModel( - unquantized_dump_model_path, *pre_calibration_exported_model, - saved_model_path, tags, signature_def_map); + unquantized_dump_model_path, debugging_exported_model, saved_model_path, + tags, signature_def_map); } return QuantizePtqModelPostCalibrationImpl( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index e0eeca13d92f20..f7dec2d2a5dee7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -607,8 +607,8 @@ def _populate_calibration_options( calib_opts.calibration_method == _CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE ): - if not calib_opts.calibration_parameters.initial_num_bins: - calib_opts.calibration_parameters.initial_num_bins = 256 + if not calib_opts.calibration_parameters.num_bins: + calib_opts.calibration_parameters.num_bins = 512 if not calib_opts.calibration_parameters.min_percentile: calib_opts.calibration_parameters.min_percentile = 0.001 if not calib_opts.calibration_parameters.max_percentile: @@ -632,8 +632,14 @@ def _populate_calibration_options( f' methods. 
calibration_method={calib_opts.calibration_method}'
     )
 
-    if not calib_opts.calibration_parameters.initial_num_bins:
-      calib_opts.calibration_parameters.initial_num_bins = 256
+    if not calib_opts.calibration_parameters.num_bins:
+      calib_opts.calibration_parameters.num_bins = 512
+
+  if calib_opts.calibration_data_dir:
+    save_model.create_empty_output_dir(
+        calib_opts.calibration_data_dir,
+        overwrite=calib_opts.force_regenerate_calibration_data,
+    )
 
 
 def _populate_quantization_options_default_values(
@@ -735,24 +741,24 @@ def _populate_quantization_options_default_values(
   if (quantization_options.op_set == quant_opts_pb2.OpSet.STABLEHLO) and (
       quantization_options.quantization_method.preset_method
       != _PresetMethod.METHOD_STATIC_RANGE_INT8
+      and quantization_options.quantization_method.preset_method
+      != _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
   ):
     raise ValueError(
         'StableHLO quantized opset currently only supports static range'
-        ' quantization via TF Quantizer.'
+        ' quantization and weight-only quantization via TF Quantizer.'
     )
 
-  if quantization_options.HasField('debugger_config'):
-    # Set `force_graph_mode_calibration` to True to avoid skipping op execution,
-    # which are not connected to return ops, during calibration execution.
-    # Setting `force_graph_mode_calibration` to True enables execution of the
-    # model in graph mode (not eager mode).
-    logging.debug(
-        'Setting `force_graph_mode_calibration = True` to ensure the debugging '
-        'model is executed in graph mode during calibration, rather than eager '
-        'mode.'
-    )
-    quantization_options.force_graph_mode_calibration = True
+  # Set `force_graph_mode_calibration` to True to avoid skipping execution of
+  # ops that are not connected to return ops during calibration.
+  # TODO: b/335031954 - Bring back support to run calibration in Eager mode.
+  logging.debug(
+      'Setting `force_graph_mode_calibration = True` to ensure the calibration'
+      ' mode is executed properly.'
+  )
+  quantization_options.force_graph_mode_calibration = True
 
+  if quantization_options.HasField('debugger_config'):
     if not quantization_options.debugger_config.log_dir_path:
       quantization_options.debugger_config.log_dir_path = '/tmp/dumps'
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py
index 4b4ac4f65fe157..87ad7a11f2e677 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py
@@ -128,7 +128,9 @@ def _restore_output_tensor_names(
   return graph_def
 
 
-def _create_empty_output_dir(output_directory: str) -> None:
+def create_empty_output_dir(
+    output_directory: str, overwrite: bool = True
+) -> None:
   """Creates the `output_directory`.
 
   If `output_directory` already exists, it recursively deletes all contents
@@ -138,10 +140,11 @@ def _create_empty_output_dir(output_directory: str) -> None:
 
   Args:
     output_directory: Output directory.
+    overwrite: Whether to clean the output directory if it exists.
   """
-  if file_io.file_exists_v2(output_directory):
+  if overwrite and file_io.file_exists_v2(output_directory):
     logging.info(
-        'Deleting existing directory for quantized model output: %s .',
+        'Deleting existing output directory: %s .',
         output_directory,
     )
     file_io.delete_recursively_v2(output_directory)
@@ -297,7 +300,7 @@ def save_model_v1(
     ValueError iff the graph does not contain a valid signature or the file
     prefix tensor is not found in the graph.
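For reference, the overwrite semantics that create_empty_output_dir gains here, restated as a hypothetical standalone C++ analog (not code from the patch):

    #include <filesystem>

    // With overwrite=true the directory is wiped and recreated; with
    // overwrite=false existing contents (e.g. reusable calibration data)
    // are left intact.
    void CreateEmptyOutputDir(const std::filesystem::path& dir,
                              bool overwrite = true) {
      namespace fs = std::filesystem;
      if (overwrite && fs::exists(dir)) fs::remove_all(dir);
      fs::create_directories(dir);  // no-op if the directory already exists
    }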
""" - _create_empty_output_dir(output_dir) + create_empty_output_dir(output_dir) v1_builder = builder.SavedModelBuilder(output_dir) graph_def = _restore_output_tensor_names(graph_def) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc index 0e756021844a5c..b91e4a23613341 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc @@ -157,7 +157,6 @@ void AddQuantizePtqPreCalibrationPasses( pm.addNestedPass( mlir::quant::CreateInsertCustomAggregationOpsPass( quantization_options.calibration_options())); - pm.addPass(mlir::quant::CreateIssueIDsOfCustomAggregationOpsPass()); } void AddQuantizePtqPostCalibrationPasses( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc index c8db1da7adace4..1d150ce7a648ea 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -69,7 +71,9 @@ void AddUnfuseMhloOpsPasses(mlir::PassManager& pm) { // Converts TF SavedModel to StableHLO module. The input TF SavedModel can have // StableHLO module serialized into a XlaCallModuleOp. (ex: JAX/PyTorch models) -void AddTFToStablehloPasses(mlir::PassManager& pm) { +void AddTFToStablehloPasses( + mlir::PassManager& pm, + llvm::ArrayRef> input_arg_shapes) { pm.addPass(mlir::odml::CreateRenameEntrypointToMainPass()); // TODO: b/230572023 - Consider improving shape inference for While op instead // of dropping the attribute. This need not be correct for models not trained @@ -97,7 +101,7 @@ void AddTFToStablehloPasses(mlir::PassManager& pm) { pm.addPass(mlir::createSymbolDCEPass()); pm.addPass(mlir::createCanonicalizerPass()); // Propagates shapes on the TensorFlow graph. - pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass(input_arg_shapes)); pm.addPass(mlir::createCanonicalizerPass()); pm.addNestedPass( mlir::TFDevice::CreateDecomposeResourceOpsPass()); @@ -110,7 +114,7 @@ void AddTFToStablehloPasses(mlir::PassManager& pm) { // Generic MLIR optimization passes. pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass(input_arg_shapes)); // Legalizes TF UniformQuantized types into MHLO. Part of the official // TF/XLA bridge component. @@ -120,9 +124,9 @@ void AddTFToStablehloPasses(mlir::PassManager& pm) { // TF -> StableHLO legalization. // Skip StatefulPartitionedCall to preserve aliased functions. 
- mlir::odml::AddLegalizeTFToStablehloPasses( - pm, /*skip_quantization_ops=*/true, - /*skip_resize=*/false, /*skip_stateful_partitioned_call=*/true); + mlir::odml::AddLegalizeTFToStablehloPasses(pm, /*skip_quantization_ops=*/true, + /*skip_resize=*/false, + /*skip_partitioned_calls=*/true); // StableHLO -> MHLO legalization for MHLO optimization. pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); // Rewrites legacy StableHLO ops. @@ -137,7 +141,8 @@ absl::Status PreprocessAndFreezeGraph( const absl::flat_hash_set& noinline_functions, mlir::ModuleOp module_op, mlir::MLIRContext* context, std::optional session, const bool run_tf_to_stablehlo, - const bool deserialize_xla_call_module) { + const bool deserialize_xla_call_module, + llvm::ArrayRef> input_arg_shapes) { mlir::PassManager pm_before_freezing_variables(context); mlir::StatusScopedDiagnosticHandler statusHandler(module_op.getContext(), /*propagate=*/true); @@ -169,7 +174,7 @@ absl::Status PreprocessAndFreezeGraph( if (run_tf_to_stablehlo) { // AddLegalizeTFToStablehloPasses expects frozen TF variables when // legalizing to stablehlo.constant. - AddTFToStablehloPasses(pm_after_freezing_variables); + AddTFToStablehloPasses(pm_after_freezing_variables, input_arg_shapes); } if (deserialize_xla_call_module) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h index 740dca6c7b106b..878b3ebdb27968 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_ +#include #include #include @@ -45,7 +46,8 @@ absl::Status PreprocessAndFreezeGraph( const absl::flat_hash_set& noinline_functions, mlir::ModuleOp module_op, mlir::MLIRContext* context, std::optional session, bool run_tf_to_stablehlo, - bool deserialize_xla_call_module); + bool deserialize_xla_call_module, + llvm::ArrayRef> input_arg_shapes = {}); // Overload of `PreprocessAndFreezeGraph` that uses the default MLIR dump file // prefix. @@ -56,10 +58,15 @@ inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op, /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, /*is_inliner_run=*/true, /*noinline_functions=*/{}, module_op, context, session, /*run_tf_to_stablehlo=*/false, - /*deserialize_xla_call_module=*/false); + /*deserialize_xla_call_module=*/false, /*input_arg_shapes=*/{}); } -void AddTFToStablehloPasses(mlir::PassManager& pm); +// TF->StableHLO has limited support for dynamic shapes. +// Some models can only be converted with explicitly provided input argument +// shapes. 
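To make the comment above concrete, a caller supplying explicit argument shapes might build them roughly as below. The helper and the two shapes are made up; TF shape dimensions are conventionally int64_t, with -1 marking a dynamic dimension, and the exact marshalling into the pass parameter is elided here:

    #include <cstdint>
    #include "llvm/ADT/SmallVector.h"

    // Hypothetical example: static shapes for a model taking an NHWC image
    // tensor and a length-10 vector; these would feed the input_arg_shapes
    // parameter threaded through the passes above.
    llvm::SmallVector<llvm::SmallVector<int64_t>> ExampleInputArgShapes() {
      return {{1, 224, 224, 3}, {10}};
    }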
+void AddTFToStablehloPasses( + mlir::PassManager& pm, + llvm::ArrayRef> input_arg_shapes = {}); } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_custom_aggregation_op_to_quant_stats.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_custom_aggregation_op_to_quant_stats.mlir index f72c9f3388c071..91e1e1d82d6150 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_custom_aggregation_op_to_quant_stats.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_custom_aggregation_op_to_quant_stats.mlir @@ -1,18 +1,18 @@ // RUN: tf-quant-opt %s -quant-convert-tf-custom-aggregator-op-to-quant-stats | FileCheck %s func.func @customAggregator(%arg0: tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>) { - %0:4 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, id = "0"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) - %1:4 = "tf.CustomAggregator"(%arg0) {id = "1"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %0:4 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, id = "0", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %1:4 = "tf.CustomAggregator"(%arg0) {id = "1", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) func.return %0#0, %1#0 : tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32> } // CHECK: func @customAggregator -// CHECK-NEXT: %[[stats:.*]] = "quantfork.stats"(%arg0) {layerStats = dense<[-1.000000e-01, 2.000000e-01]> : tensor<2xf32>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> +// CHECK-NEXT: %[[stats:.*]] = "quantfork.stats"(%arg0) <{layerStats = dense<[-1.000000e-01, 2.000000e-01]> : tensor<2xf32>}> : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[stats]], %arg0 func.func @doNotHandleNoMinMaxCases(%arg0: tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>) { - %0:4 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, id = "1"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) - %1:4 = "tf.CustomAggregator"(%arg0) {max = 0.2 : f32, id = "2"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) - %2:4 = "tf.CustomAggregator"(%arg0) {id = "3"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %0:4 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, id = "1", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %1:4 = "tf.CustomAggregator"(%arg0) {max = 0.2 : f32, id = "2", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %2:4 = "tf.CustomAggregator"(%arg0) {id = "3", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) func.return %0#0, %1#0, 
%2#0 : tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32> } // CHECK: func @doNotHandleNoMinMaxCases diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir index 052da55dce336d..85480a46e352fe 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir @@ -26,10 +26,10 @@ module { // CalibrationOptions(calibration_method=CALIBRATION_METHOD_MIN_MAX) // MIN-MAX-CHECK: func @wrap_composite_func -// MIN-MAX-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) -// MIN-MAX-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// MIN-MAX-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 1 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// MIN-MAX-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) // MIN-MAX-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// MIN-MAX-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// MIN-MAX-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 1 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) // MIN-MAX-CHECK-NEXT: return [[res]] : tensor<*xf32> // MIN-MAX-CHECK: func @no_composite_func @@ -43,10 +43,10 @@ module { // CalibrationOptions(calibration_method=CALIBRATION_METHOD_AVERAGE_MIN_MAX) // AVERAGE-MIN-MAX-CHECK: func @wrap_composite_func -// AVERAGE-MIN-MAX-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) -// AVERAGE-MIN-MAX-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) 
-> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// AVERAGE-MIN-MAX-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 2 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// AVERAGE-MIN-MAX-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 2 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) // AVERAGE-MIN-MAX-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// AVERAGE-MIN-MAX-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// AVERAGE-MIN-MAX-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 2 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) // AVERAGE-MIN-MAX-CHECK-NEXT: return [[res]] : tensor<*xf32> // AVERAGE-MIN-MAX-CHECK: func @no_composite_func @@ -60,13 +60,13 @@ module { // CalibrationOptions( // calibration_method=CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, -// calibration_parameters=CalibrationParameters(initial_num_bins=256, min_percentile=0.001, max_percentile=99.999) +// calibration_parameters=CalibrationParameters(num_bins=256, min_percentile=0.001, max_percentile=99.999) // ) // HISTOGRAM-PERCENTILE-CHECK: func @wrap_composite_func -// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) -// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 3 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_3", max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 3 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_3", max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-PERCENTILE-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// HISTOGRAM-PERCENTILE-CHECK-NEXT: 
[[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 3 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_3", max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-PERCENTILE-CHECK-NEXT: return [[res]] : tensor<*xf32> // HISTOGRAM-PERCENTILE-CHECK: func @no_composite_func @@ -80,13 +80,13 @@ module { // CalibrationOptions( // calibration_method=CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, -// calibration_parameters=CalibrationParameters(initial_num_bins=256) +// calibration_parameters=CalibrationParameters(num_bins=256) // ) // HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @wrap_composite_func -// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) -// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 4 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_4", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 4 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_4", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 4 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_4", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: return [[res]] : tensor<*xf32> // HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @no_composite_func @@ -100,13 +100,13 @@ module { // CalibrationOptions( // calibration_method=CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, -// 
calibration_parameters=CalibrationParameters(initial_num_bins=256) +// calibration_parameters=CalibrationParameters(num_bins=256) // ) // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @wrap_composite_func -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 5 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_5", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 5 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_5", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 5 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_5", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: return [[res]] : tensor<*xf32> // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @no_composite_func @@ -120,13 +120,13 @@ module { // CalibrationOptions( // calibration_method=CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, -// calibration_parameters=CalibrationParameters(initial_num_bins=256) +// calibration_parameters=CalibrationParameters(num_bins=256) // ) // HISTOGRAM-MSE-SYMMETRIC-CHECK: func @wrap_composite_func -// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) -// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: 
[[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 6 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_6", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 6 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_6", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 6 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_6", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: return [[res]] : tensor<*xf32> // HISTOGRAM-MSE-SYMMETRIC-CHECK: func @no_composite_func @@ -144,14 +144,17 @@ module { module { // CHECK-LABEL: func.func @main func.func @main(%arg0: tensor, %arg1: tensor<100352x10xf32>) -> tensor { - // CHECK-DAG: %[[ARG0_ID:.*]] = "tf.Identity"(%arg0) - // CHECK-DAG: %[[ARG1_ID:.*]] = "tf.Identity"(%arg1) - // CHECK-DAG: %[[ARG0_AGG:.*]] = "tf.CustomAggregator"(%[[ARG0_ID]]) - // CHECK-DAG: %[[ARG1_AGG:.*]] = "tf.CustomAggregator"(%[[ARG1_ID]]) - // CHECK: %[[RES:.*]] = "tf.XlaCallModule"(%[[ARG0_AGG]], %[[ARG1_AGG]]) - // CHECK: %[[RES_AGG:.*]] = "tf.CustomAggregator"(%[[RES]]) - // CHECK-DAG: %[[RES_ID:.*]] = "tf.Identity"(%[[RES_AGG]]) - // CHECK: return %[[RES_ID]] : tensor + // MIN-MAX-CHECK-DAG: %[[ARG0_ID:.*]] = "tf.Identity"(%arg0) + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG0_ID]]) + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK-DAG: %[[ARG1_ID:.*]] = "tf.Identity"(%arg1) + // MIN-MAX-CHECK: %[[ARG1_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG1_ID]]) + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_fn_1_arg_1_calibration_method_1" + // MIN-MAX-CHECK: %[[RES:.*]] = "tf.XlaCallModule"(%[[ARG0_AGG]], %[[ARG1_AGG]]) + // MIN-MAX-CHECK: %[[RES_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[RES]]) + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_fn_1_calibration_method_1" + // MIN-MAX-CHECK: %[[RES_ID:.*]] = "tf.Identity"(%[[RES_AGG]]) + // MIN-MAX-CHECK: return %[[RES_ID]] : tensor %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor %1 = "tf.Identity"(%arg1) {device = ""} : (tensor<100352x10xf32>) -> tensor<100352x10xf32> %2 = "tf.XlaCallModule"(%0, %1) <{ @@ -162,7 +165,8 @@ module { }> { _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", - _tfl_quant_trait = "fully_quantizable" + _tfl_quant_trait = "fully_quantizable", + _quantization_method = 
"static_range_ptq { }" } : (tensor, tensor<100352x10xf32>) -> tensor %3 = "tf.Identity"(%2) {device = ""} : (tensor) -> tensor return %3 : tensor @@ -175,3 +179,174 @@ module { return %0 : tensor } } + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_0 = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %0 = "tf.Sum"(%arg0, %cst) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.If"(%1, %arg0) <{else_branch = @cond_false_80, is_stateless = true, then_branch = @cond_true_70}> {Tcond = i1, Tin = [f32], Tout = [i1, f32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor, tensor<1x4xf32>) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + + func.func private @cond_false_80(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_false_8"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%arg0, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + // MIN-MAX-CHECK: func.func private @cond_false_80 + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_1_calibration_method_1" + + func.func private @cond_true_70(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_true_7"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() <{value = 
dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%arg0, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + // MIN-MAX-CHECK: func.func private @cond_true_70 + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_2_arg_0_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_2_calibration_method_1" + + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_2 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_3 = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_4 = "tf.Const"() <{value = dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, 
-0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_5 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %0 = "tf.Sum"(%arg0, %cst_0) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.IfRegion"(%1) <{_else_func_name = "cond_false_80", _then_func_name = "cond_true_70", is_stateless = true}> ({ + %4 = "tf.Identity"(%cst_3) {device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%arg0, %cst_1, %cst_2) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %6 = "tf.Identity"(%5) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }, { + %4 = "tf.Identity"(%cst_3) {device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%arg0, %cst_4, %cst_5) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %6 = "tf.Identity"(%5) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + // MIN-MAX-CHECK: func.func @serving_default + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK: "tf.IfRegion" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_2_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_1_calibration_method_1" + + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 
1836 : i32}, tf_saved_model.semantics} { + func.func @main(%arg0: tensor<10x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<10x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<0.000000e+00>: tensor<10x1024x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + return %0 : tensor<10x1x3xf32> + } + // MIN-MAX-CHECK: func.func @main + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_with_relu_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_with_relu_fn_1_calibration_method_1" + + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %1 = stablehlo.maximum %0, %cst : tensor<10x1x3xf32> + return %1 : tensor<10x1x3xf32> + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + func.func @main(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<1.000000e+01> : tensor + %cst_0 = stablehlo.constant dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32> + %c = stablehlo.constant dense : tensor + %cst_1 = stablehlo.constant dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32> + %cst_2 = stablehlo.constant dense<-0.000000e+00> : tensor + %cst_3 = stablehlo.constant dense<[[0.335351914, 0.084816426, -0.664676845]]> : tensor<1x3xf32> + %cst_4 = stablehlo.constant dense<[[0.117216609, 0.933735609, 0.0728900209]]> : tensor<1x3xf32> + %0 = stablehlo.reduce(%arg0 init: %cst_2) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x4xf32>, tensor) -> tensor + %1 = stablehlo.compare GT, %0, %cst : (tensor, tensor) -> 
tensor + %2:2 = "stablehlo.if"(%1) ({ + %3 = "tf.XlaCallModule"(%arg0, %cst_0, %cst_3) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_2, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_2", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + stablehlo.return %c, %3 : tensor, tensor<1x3xf32> + }, { + %3 = "tf.XlaCallModule"(%arg0, %cst_1, %cst_4) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_1, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + stablehlo.return %c, %3 : tensor, tensor<1x3xf32> + }) : (tensor) -> (tensor, tensor<1x3xf32>) + return %2#1 : tensor<1x3xf32> + } + // MIN-MAX-CHECK: func.func @main + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_with_bias_same_shape_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK: "stablehlo.if" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_with_bias_same_shape_fn_2_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_with_bias_same_shape_fn_1_calibration_method_1" + + func.func private @composite_dot_general_with_bias_same_shape_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_with_bias_same_shape_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir deleted file mode 100644 index 6a1621cdf17e89..00000000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: tf-quant-opt %s -quant-issues-ids-of-custom-aggregation-ops | FileCheck %s - -func.func @issue_ids(%arg0: tensor<*xf32>, %arg1: 
tensor<*xf32>) -> tensor<*xf32> {
-  %0:4 = "tf.CustomAggregator"(%arg1) {id = ""} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<*xi64>)
-  %1:4 = "tf.CustomAggregator"(%arg0) {id = ""} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<*xi64>)
-  %2 = "tf.AddV2"(%1, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-  %3:4 = "tf.CustomAggregator"(%2) {id = ""} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<*xi64>)
-  func.return %3 : tensor<*xf32>
-}
-
-
-// CHECK: func @issue_ids
-// CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = "0"}> : (tensor<*xf32>)
-// CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "1"}> : (tensor<*xf32>)
-// CHECK-NEXT: [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = "2"}> : (tensor<*xf32>)
-// CHECK-NEXT: return [[res]] : tensor<*xf32>
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir
index 6a7f9da6bc5563..b0ce385ba41628 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir
@@ -421,3 +421,88 @@ func.func @conv3d_with_bias(%arg0: tensor<1x3x4x3x3xf32>) -> (tensor<1x3x2x3x2xf
 // CHECK-LABEL: private @composite_conv3d_with_bias_and_relu6_fn_1
 // CHECK-LABEL: private @composite_conv3d_with_bias_fn_1
 }
+
+// -----
+
+// Test that the names of composite functions are deterministic. There are 3
+// unsorted functions in this module and each function has 2 quantizable ops.
+module { + func.func @float_conv_3(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "tf.Relu6"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + func.return %2, %5 : tensor<*xf32>, tensor<*xf32> + } + + func.func @float_conv_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "tf.Relu6"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + func.return %2, %5 : tensor<*xf32>, tensor<*xf32> + } + + func.func @float_conv_2(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "tf.Relu6"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + func.return %2, %5 : tensor<*xf32>, tensor<*xf32> + } +} + +// CHECK-LABEL: @float_conv_3 +// CHECK: "tf.PartitionedCall" +// 
CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_6 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_5 + +// CHECK-LABEL: @float_conv_1 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_2 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1 + +// CHECK-LABEL: @float_conv_2 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_4 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_3 + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir index 0f3c7024dba4b4..7e020bd279d2b6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir @@ -21,7 +21,7 @@ func.func private @conv(%input: tensor<1x3x4x3xf32> {tf._user_specified_name = " // CHECK-DAG: [[bias:%.+]] = "arith.constant"() <{value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>}> : () -> tensor<2xf32> // CHECK-DAG: [[weight:%.+]] = "arith.constant"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2x!quant.uniform> -// CHECK: [[q_input:%.+]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: [[q_input:%.+]] = "quantfork.qcast"([[ARG0:%arg[0-9]+]]) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> // CHECK-NEXT: [[q_bias:%.+]] = "quantfork.qcast"([[bias]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform> // CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) <{config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<*x!quant.uniform> // CHECK-NEXT: [[res:%.+]] = "quantfork.dcast"([[conv]]) : (tensor<*x!quant.uniform>) -> tensor<*xf32> diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir index e3bda3f5d09af9..b8b31d880e0c78 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir @@ -15,7 +15,7 @@ module { // CHECK: %[[cst:.*]] = "arith.constant"() <{value = dense<0.000000e+00> : tensor<2x1024xf32>}> : () -> tensor<2x1024xf32> // CHECK: %[[q_cst:.*]] = "quantfork.qcast"(%[[cst]]) : (tensor<2x1024xf32>) -> tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>> -// CHECK: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_cst]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x2x2x3xf32>, tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>>) -> tensor<*xf32> +// CHECK: %[[out:.*]] = "tf.PartitionedCall"([[ARG0:%arg[0-9]+]], %[[q_cst]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x2x2x3xf32>, tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>>) -> tensor<*xf32> // CHECK: "func.return"(%[[out]]) : (tensor<*xf32>) -> () } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir 
b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir index f24b6399774f08..6a6c176ad37d2e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir @@ -21,7 +21,7 @@ func.func private @conv(%input: tensor<1x3x4x3xf32> {tf._user_specified_name = " // CHECK-DAG: [[bias:%.+]] = "arith.constant"() <{value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>}> : () -> tensor<2xf32> // CHECK-DAG: [[weight:%.+]] = "arith.constant"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2x!quant.uniform> -// CHECK: [[q_input:%.+]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: [[q_input:%.+]] = "quantfork.qcast"([[ARG0:%arg[0-9]+]]) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> // CHECK-NEXT: [[q_bias:%.+]] = "quantfork.qcast"([[bias]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform> // CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) <{config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<*x!quant.uniform> // CHECK-NEXT: [[res:%.+]] = "quantfork.dcast"([[conv]]) : (tensor<*x!quant.uniform>) -> tensor<*xf32> diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index 21600fb78083a5..4397b4fc5a3f2d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -30,6 +30,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -85,6 +86,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_xla//xla:xla_data_proto_cc", ], ) @@ -100,5 +102,6 @@ tf_cc_test( "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h index 5a1734bf6bf026..702e19506d2fd6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h @@ -112,7 +112,7 @@ class ConvertFakeQuantOpToQuantOps { Value input = tf_op.getInputs(); int quant_dim = -1; - auto input_type = input.getType().template cast(); + auto input_type = mlir::cast(input.getType()); if (PerAxis) { if (!input_type.hasRank()) { tf_op.emitError("The input should have known rank for per-channel op."); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.cc index 264c6c508a60f7..1392bf4de2a92f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.cc @@ -16,14 +16,15 @@ limitations under the License. 
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace quant { UnrankedTensorType CreateUnknownShapeFromElementType(Type tensor_type) { - if (!tensor_type.cast()) return UnrankedTensorType(); + if (!mlir::cast(tensor_type)) return UnrankedTensorType(); return UnrankedTensorType::get( - tensor_type.cast().getElementType()); + mlir::cast(tensor_type).getElementType()); } } // namespace quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc index 967af993c0bcf7..430d5ff6ba2047 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h" @@ -66,9 +67,9 @@ constexpr std::array kSuffixes = {"_min_val", "_max_val"}; Attribute GetWindowStridesValue( PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { - ArrayAttr stride = identifier_to_attr["strides"].dyn_cast(); - const int stride_h = stride[1].cast().getInt(); - const int stride_w = stride[2].cast().getInt(); + ArrayAttr stride = mlir::dyn_cast(identifier_to_attr["strides"]); + const int stride_h = mlir::cast(stride[1]).getInt(); + const int stride_w = mlir::cast(stride[2]).getInt(); return rewriter.getI64ArrayAttr({stride_h, stride_w}); } @@ -79,23 +80,24 @@ Attribute GetLhsDilationValue(PatternRewriter& rewriter, Attribute GetRhsDilationValue(PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { - ArrayAttr dilations = identifier_to_attr["dilations"].dyn_cast(); - const int dilation_h = dilations[1].cast().getInt(); - const int dilation_w = dilations[2].cast().getInt(); + ArrayAttr dilations = + mlir::dyn_cast(identifier_to_attr["dilations"]); + const int dilation_h = mlir::cast(dilations[1]).getInt(); + const int dilation_w = mlir::cast(dilations[2]).getInt(); return rewriter.getI64ArrayAttr({dilation_h, dilation_w}); } Attribute GetPaddingValue(PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { llvm::StringRef padding = - identifier_to_attr["padding"].dyn_cast().getValue(); + mlir::dyn_cast(identifier_to_attr["padding"]).getValue(); return rewriter.getStringAttr(padding); } Attribute GetExplicitPaddingValue( PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { ArrayAttr explicit_padding = - identifier_to_attr["explicit_paddings"].dyn_cast(); + mlir::dyn_cast(identifier_to_attr["explicit_paddings"]); return explicit_padding; } @@ -167,7 +169,7 @@ LogicalResult CheckIfAttrIs8Bit(const std::string& attr, Operation* op, element_type = getElementTypeOrSelf(op->getOpResult(0).getType()); } if (element_type) { - is_8_bit = element_type.isa(); + is_8_bit = mlir::isa(element_type); return success(); } return failure(); @@ -295,7 +297,8 @@ LogicalResult 
FillAttributesForUniformQuantizedConvolutionOp( auto feature_group_cnt_attr = llvm::StringRef("feature_group_count"); int feature_group_cnt = 1; - ShapedType input_shape = op->getOperand(0).getType().dyn_cast(); + ShapedType input_shape = + mlir::dyn_cast(op->getOperand(0).getType()); if (!input_shape) { return op->emitError( "Only input with known shape is supported for Uniform Quantized " @@ -425,7 +428,8 @@ LogicalResult FillAttributesForUniformRequantizeOp( activation_quantization_axis = GetQuantizationAxis(rewriter, op, /*operand_index=*/0); - auto output_scale_type = op->getOperand(3).getType().dyn_cast(); + auto output_scale_type = + mlir::dyn_cast(op->getOperand(3).getType()); if (!output_scale_type) { return failure(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc index f1d7a6ae576c7b..b22726de30aeaa 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_format.h" #include "llvm/ADT/ArrayRef.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" #include "xla/xla_data.pb.h" @@ -34,8 +35,7 @@ Value GetDimValue(OpBuilder &builder, Location loc, Value shape_value, return builder.create( loc, RankedTensorType::get( - {}, - shape_value.getType().template cast().getElementType()), + {}, mlir::cast(shape_value.getType()).getElementType()), /*input=*/shape_value, /*begin=*/Create1DConstValue(builder, loc, {dim}), /*end=*/Create1DConstValue(builder, loc, {dim + 1}), @@ -109,14 +109,14 @@ Value PadForDynamicShapedInputSamePadding( CreateConstValue(builder, loc, {rank}, shape)); }; - ShapedType filter_shape = filter.getType().template cast(); + ShapedType filter_shape = mlir::cast(filter.getType()); Value input_shape_value = builder.create( loc, RankedTensorType::get({num_dims}, builder.getI32Type()), input); auto scalar_to_rank1 = [&](Value value) { return reshape_op(value, {1}); }; for (int i : llvm::seq(1, num_dims - 1)) { Value input_size_i = GetDimValue(builder, loc, input_shape_value, i); - const int stride_i = strides[i].cast().getInt(); - const int dilation_i = dilations[i].cast().getInt(); + const int stride_i = mlir::cast(strides[i]).getInt(); + const int dilation_i = mlir::cast(dilations[i]).getInt(); const int filter_i = filter_shape.getDimSize(i - 1); Value pad_i_low, pad_i_high; GetSamePaddingValues(builder, loc, input_size_i, filter_i, dilation_i, @@ -154,21 +154,21 @@ Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, StringAttr conv_padding, ArrayAttr explicit_paddings, Value &padding, int num_dims) { - ShapedType input_shape = input.getType().template cast(); + ShapedType input_shape = mlir::cast(input.getType()); SmallVector spatial_dims(num_dims - 2); absl::c_iota(spatial_dims, 1); bool has_dynamic_spatial_dim = absl::c_any_of( spatial_dims, [&input_shape](int64_t dim) { return input_shape.isDynamicDim(dim); }); - if (conv_padding.strref().equals("SAME") && has_dynamic_spatial_dim) { + if (conv_padding.strref() == "SAME" && has_dynamic_spatial_dim) { return PadForDynamicShapedInputSamePadding( builder, loc, input, 
filter, input_zp_value, strides, dilations, conv_padding, padding, num_dims); } - ShapedType filter_shape = filter.getType().template cast(); + ShapedType filter_shape = mlir::cast(filter.getType()); SmallVector padding_values(2 * num_dims, 0); - if (conv_padding.strref().equals("EXPLICIT")) { + if (conv_padding.strref() == "EXPLICIT") { if (explicit_paddings.size() != 2 * num_dims) { emitError(loc, absl::StrFormat( @@ -178,16 +178,16 @@ Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, } for (int i : spatial_dims) { padding_values[2 * i] = - explicit_paddings[2 * i].cast().getInt(); + mlir::cast(explicit_paddings[2 * i]).getInt(); padding_values[2 * i + 1] = - explicit_paddings[2 * i + 1].cast().getInt(); + mlir::cast(explicit_paddings[2 * i + 1]).getInt(); } - } else if (conv_padding.strref().equals("SAME")) { + } else if (conv_padding.strref() == "SAME") { for (int i : spatial_dims) { int input_size = input_shape.getDimSize(i); int filter_size = filter_shape.getDimSize(i - 1); - int stride_i = strides[i].cast().getInt(); - int dilation_i = dilations[i].cast().getInt(); + int stride_i = mlir::cast(strides[i]).getInt(); + int dilation_i = mlir::cast(dilations[i]).getInt(); int out_size = tflite::ComputeOutSize(kTfLitePaddingSame, input_size, filter_size, stride_i, dilation_i); @@ -243,7 +243,7 @@ Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, // // packed_value = bitwise_or(packed_low, packed_high) Value PackOperand(OpBuilder &builder, Location loc, Value value, int pack_dim) { - ShapedType value_type = value.getType().cast(); + ShapedType value_type = mlir::cast(value.getType()); const int rank = value_type.getRank(); SmallVector packed_shape(value_type.getShape().begin(), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc index cc4bbb344026da..cbcda677b87733 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" @@ -51,7 +52,8 @@ void PackOperandTestHelper( DenseIntElementsAttr packed_value_attr; ASSERT_TRUE(matchPattern(packed_value, m_Constant(&packed_value_attr))); - ShapedType packed_shape_type = packed_value.getType().dyn_cast(); + ShapedType packed_shape_type = + mlir::dyn_cast(packed_value.getType()); llvm::SmallVector packed_shape(packed_shape_type.getShape().begin(), packed_shape_type.getShape().end()); EXPECT_THAT(packed_shape, testing::ElementsAreArray(expected_packed_shape)); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/BUILD new file mode 100644 index 00000000000000..ecc3c9e7ca6ebe --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/BUILD @@ -0,0 +1,103 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") + +package_group( + name = "internal_visibility_allowlist_package", + packages = [ + "//tensorflow/compiler/mlir/lite/...", + "//tensorflow/compiler/mlir/quantization/...", + "//tensorflow/compiler/mlir/tf2xla/transforms/...", + "//tensorflow/lite/...", + ] + internal_visibility_allowlist(), +) + +package( + # copybara:uncomment default_applicable_licenses = ["@stablehlo//:license"], + default_visibility = [ + ":internal_visibility_allowlist_package", + "//tensorflow:__pkg__", + ], + licenses = ["notice"], +) + +cc_library( + name = "tf_to_stablehlo", + srcs = [ + "tf_to_stablehlo.cc", + ], + hdrs = [ + "tf_to_stablehlo.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:saved_model_import", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", + "//tensorflow/compiler/mlir/tensorflow/transforms:shape_inference_pass", + "//tensorflow/core:core_cpu_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], + alwayslink = True, +) + +tf_cc_binary( + name = "tf-to-stablehlo-translate", + srcs = [ + "tf_to_stablehlo_translate.cc", + ], + visibility = [":internal_visibility_allowlist_package"], + deps = [ + ":tf_to_stablehlo", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + 
"@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + +glob_lit_tests( + name = "all_tests", + data = [":test_utilities"], + default_tags = [ + "no_oss", + "no_pip", + ], + driver = "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo:run_lit.sh", + size_override = { + }, + tags_override = { + }, + test_file_exts = ["mlir"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + ":tf-to-stablehlo-translate", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + "@llvm-project//mlir:run_lit.sh", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/README.md b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/README.md new file mode 100644 index 00000000000000..a65de3c38df001 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/README.md @@ -0,0 +1,123 @@ +# Tensorflow SavedModel to StableHLO (tf-to-stablehlo-translate) + +Converts TensorFlow models (SavedModel or MLIR module) to StableHLO MLIR +modules, preserving model structure and signatures. It enables seamless +integration of TensorFlow models into MLIR-based compiler frameworks for further +optimization and deployment. + +## C++ APIs + +```bash +tf-to-stablehlo-translate \ + --input-path=/path/to/model \ + [--exported-model-signatures=signature1,signature2] \ + [--tag-names=tag1,tag2] \ + [--input-arg-shapes-str=arg-name:shape,...] \ + [--e] \ + [--output-filename=/path/to/output.mlir] +``` + +* `--input-path`: The path to the input TensorFlow SavedModel or MLIR module + with .mlir extension. +* `--exported-model-signatures`: Comma-separated list of exported model + signatures to convert. Ignored for MLIR input. +* `--tags`: Comma-separated list of tags for loading SavedModel. Ignored for + MLIR input. +* `--input-arg-shapes`: A string representation of input argument shapes for + 'main' entry-point, separating tensors with ':', dimension with ',', and + using '?' for unknown sizes. For example, `input-arg-shapes=1,2::1,?` + expresses argument shapes `[1,2]`, `[]` and `[1,?]`. +* `--e`: Elide large elements attrs while dumping the output StableHLO. +* `--output_filename`: Path to the output file where the textual StableHLO MLIR + module will be written (default: stdout). + + +### Examples + +* To convert [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) +model to StableHLO with static input shape `4x3x224x224` for input argument with +type `tensor`. + +```bash +tf-to-stablehlo-translate --input-arg-shapes=4,3,224,224 +``` + +* To convert +[google-bert/bert-large-uncased](https://huggingface.co/google-bert/bert-large-uncased) +to StableHLO with static input shapes `1x12`, `1x12`, and `1x12` for input +arguments with types `tensor, tensor, tensor`. + +```bash +tf-to-stablehlo-translate --input-arg-shapes=1,12:1,12:1,12 +``` + +### Dependencies + +* TensorFlow +* MLIR +* Abseil (absl) + +## Python APIs + + +### `savedmodel_to_stablehlo` + +Converts a TensorFlow SavedModel into StableHLO bytecode. 
+
+```Python
+from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tf2shlo
+
+stablehlo_bytes = tf2shlo.savedmodel_to_stablehlo(
+    input_path="/path/to/your/savedmodel",
+    exported_model_signatures=["serving_default"],
+    tag_names=["serve"],
+    input_arg_shapes_str="1,28,28,3::32"
+)
+
+```
+
+#### Arguments:
+
+* `input_path` (required): Path to your SavedModel directory.
+* `exported_model_signatures` (optional): List of signature names to convert.
+  Defaults to `["serving_default"]`.
+* `tag_names` (optional): List of tags associated with the SavedModel. Defaults
+  to `["serve"]`.
+* `input_arg_shapes_str` (optional): A string representation of input argument
+  shapes for the 'main' entry-point, separating tensors with ':', dimensions
+  with ',', and using '?' for unknown sizes. For example,
+  `input_arg_shapes_str="1,2::1,?"` expresses argument shapes `[1,2]`, `[]`,
+  and `[1,?]`.
+
+#### Error Handling
+
+An exception will be raised with details about the error.
+
+### `tensorflow_module_to_stablehlo`
+
+Converts a TensorFlow MLIR module string into StableHLO bytecode.
+
+```Python
+from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tf2shlo
+
+stablehlo_bytes = tf2shlo.tensorflow_module_to_stablehlo(
+    module_op_str="your_tensorflow_mlir_module_string",
+    input_arg_shapes_str="1,28,28,3::32"
+)
+```
+
+#### Arguments:
+
+* `module_op_str` (required): String containing the TensorFlow MLIR module.
+* `input_arg_shapes_str` (optional): A string representation of input argument
+  shapes for the 'main' entry-point, separating tensors with ':', dimensions
+  with ',', and using '?' for unknown sizes. For example,
+  `input_arg_shapes_str="1,2::1,?"` expresses argument shapes `[1,2]`, `[]`,
+  and `[1,?]`.
+
+#### Error Handling
+
+Returns `py::none()` (equivalent to Python's `None`) if there's an error; an
+exception with details about the error will also be raised.
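+
+#### Example
+
+A minimal sketch of calling `tensorflow_module_to_stablehlo` defensively. The
+TF-dialect module string and output path below are illustrative placeholders,
+and catching a generic `Exception` is an assumption, since the exact exception
+type surfaced by the binding is not specified above.
+
+```Python
+from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tf2shlo
+
+# A tiny TF-dialect module whose input has an unknown leading dimension
+# (illustrative only); the shape string below refines it to a static shape.
+module_str = """
+module {
+  func.func @main(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+    %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+    func.return %0 : tensor<?x4xf32>
+  }
+}
+"""
+
+try:
+    stablehlo_bytes = tf2shlo.tensorflow_module_to_stablehlo(
+        module_op_str=module_str,
+        input_arg_shapes_str="1,4",
+    )
+except Exception as e:  # Kept generic on purpose; narrow if the binding documents a type.
+    raise SystemExit(f"StableHLO conversion failed: {e}")
+
+if stablehlo_bytes is None:
+    # Per the error-handling note above, a failure may also surface as `None`.
+    raise SystemExit("Conversion returned no StableHLO module.")
+
+# Persist the StableHLO bytecode for later use with MLIR tooling.
+with open("/tmp/model.stablehlo.mlirbc", "wb") as out_f:
+    out_f.write(stablehlo_bytes)
+```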
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/BUILD
new file mode 100644
index 00000000000000..f7a1c77026d215
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/BUILD
@@ -0,0 +1,108 @@
+load(
+    "//tensorflow:tensorflow.default.bzl",
+    "get_compatible_with_portable",
+    "tf_py_strict_test",
+    "tf_python_pybind_extension",
+)
+load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist")
+
+package_group(
+    name = "internal_visibility_allowlist_package",
+    packages = [
+        "//tensorflow/compiler/mlir/lite/...",
+        "//tensorflow/compiler/mlir/quantization/...",
+        "//tensorflow/compiler/mlir/tf2xla/transforms/...",
+        "//tensorflow/lite/...",
+    ] + internal_visibility_allowlist(),
+)
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["@stablehlo//:license"],
+    default_visibility = [
+        ":internal_visibility_allowlist_package",
+        "//tensorflow:__pkg__",
+        "//tensorflow/python:__pkg__",
+    ],
+    licenses = ["notice"],
+)
+
+# copybara:uncomment_begin(google-only)
+# tf_py_strict_test(
+#     name = "tensorflow_to_stablehlo_test",
+#     testonly = 1,
+#     srcs = ["integration_test/tensorflow_to_stablehlo_test.py"],
+#     deps = [
+#         ":pywrap_tensorflow_to_stablehlo",
+#         "//testing/pymocks:matchers",
+#         "//third_party/py/mlir",
+#         "//third_party/py/mlir:ir",
+#         "//third_party/py/mlir:stablehlo_dialect",
+#         "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
+#         "//tensorflow:tensorflow_py",
+#         "//tensorflow/compiler/mlir/stablehlo",
+#         "//tensorflow/python/framework:test_lib",
+#         "//tensorflow/python/platform:client_testlib",
+#         "//tensorflow/python/types:core",
+#     ],
+# )
+# copybara:uncomment_end
+
+# This is a header-only target. The `pywrap_tensorflow_to_stablehlo_lib_*` targets expose only
+# the symbols that are required by `pywrap_tensorflow_to_stablehlo`, which translates them into
+# Python functions. The only intended user of this library is `pywrap_tensorflow_to_stablehlo`.
+# Not letting `pywrap_tensorflow_to_stablehlo` depend directly on sub-libraries, and instead
+# having a consolidated impl library `pywrap_tensorflow_to_stablehlo_lib_impl`, allows the
+# maintainers to avoid declaring multiple impl libraries to `libtensorflow_cc` and
+# `lib_pywrap_tensorflow_internal`, which is required to avoid ODR violations.
+cc_library(
+    name = "pywrap_tensorflow_to_stablehlo_lib_header_only",
+    srcs = [],
+    hdrs = ["pywrap_tensorflow_to_stablehlo_lib.h"],
+    compatible_with = get_compatible_with_portable(),
+    visibility = ["//visibility:private"],  # ONLY for `pywrap_tensorflow_to_stablehlo`.
+    deps = [
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+# See the comments for `pywrap_tensorflow_to_stablehlo_lib_header_only`.
+cc_library(
+    name = "pywrap_tensorflow_to_stablehlo_lib_impl",
+    srcs = ["pywrap_tensorflow_to_stablehlo_lib.cc"],
+    hdrs = ["pywrap_tensorflow_to_stablehlo_lib.h"],
+    compatible_with = get_compatible_with_portable(),
+    visibility = [
+        "//tensorflow:__pkg__",  # For libtensorflow_cc.so.
+        "//tensorflow/python:__pkg__",  # For lib_pywrap_tensorflow_internal.so.
+ ], + deps = [ + "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo:tf_to_stablehlo", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:lib", + "//third_party/python_runtime:headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + ], +) + +tf_python_pybind_extension( + name = "pywrap_tensorflow_to_stablehlo", + srcs = ["pywrap_tensorflow_to_stablehlo.cc"], + pytype_srcs = ["pywrap_tensorflow_to_stablehlo.pyi"], + # Each dependency MUST be either header-only or exclusive. + deps = [ + ":pywrap_tensorflow_to_stablehlo_lib_header_only", + "//third_party/python_runtime:headers", + "@pybind11", + "@pybind11_abseil//pybind11_abseil:absl_casters", + "@pybind11_abseil//pybind11_abseil:status_casters", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/integration_test/tensorflow_to_stablehlo_test.py b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/integration_test/tensorflow_to_stablehlo_test.py new file mode 100644 index 00000000000000..28b224abdd5db9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/integration_test/tensorflow_to_stablehlo_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import tempfile +from mlir import ir +from mlir.dialects import stablehlo +import tensorflow as tf +from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tensorflow_to_stablehlo +from tensorflow.python.platform import test + + +def build_savedmodel(tempdir) -> str: + + class AddOneModel(tf.keras.Model): + + def call(self, x): + return x + 1 + + model = AddOneModel() + + x_train = tf.constant([1, 2, 3, 4, 5], dtype=tf.float32) + y_train = tf.constant([2, 3, 4, 5, 6], dtype=tf.float32) + + model.compile(optimizer='sgd', loss='mse') + model.fit(x_train, y_train, epochs=1) + + path = tempdir + '/add_one_model' + model.save(path) + return path + + +class TensorflowToStableHLOTest(test.TestCase): + + def test_saved_model_to_stablehlo(self): + with tempfile.TemporaryDirectory() as tempdir: + path = build_savedmodel(tempdir) + module_bytecode = tensorflow_to_stablehlo.savedmodel_to_stablehlo( + input_path=path, input_arg_shapes_str='4' + ) + with ir.Context() as ctx: + stablehlo.register_dialect(ctx) + module = ir.Module.parse(module_bytecode) + self.assertIn('stablehlo.add %arg0, %cst : tensor<4xf32>', str(module)) + + def test_tf_mlir_to_stablehlo(self): + assembly = """ + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0 : tensor) -> tensor { + %cst = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %0 = "tf.Add"(%arg0, %cst): (tensor, tensor) -> tensor + func.return %0 : tensor + } + } + """ + module_bytecode = tensorflow_to_stablehlo.tensorflow_module_to_stablehlo( + module=assembly, + input_arg_shapes_str='4', + ) + with ir.Context() as ctx: + stablehlo.register_dialect(ctx) + module = ir.Module.parse(module_bytecode) + self.assertIn('stablehlo.add %arg0, %cst : tensor<4xf32>', str(module)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.cc b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.cc new file mode 100644 index 00000000000000..1d1f775f5dfda1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.cc @@ -0,0 +1,97 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include <vector>
+
+#include "pybind11/pybind11.h"  // from @pybind11
+#include "pybind11/pytypes.h"  // from @pybind11
+#include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil  // IWYU pragma: keep
+#include "pybind11_abseil/status_casters.h"  // from @pybind11_abseil  // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h"
+
+namespace py = pybind11;
+
+namespace {
+
+using mlir::tensorflow_to_stablehlo::pywrap::PywrapSavedModelToStablehlo;
+using mlir::tensorflow_to_stablehlo::pywrap::PywrapTfModuleToStablehlo;
+
+}  // namespace
+
+PYBIND11_MODULE(pywrap_tensorflow_to_stablehlo, m) {
+  m.doc() = "TensorFlow to StableHLO APIs.";
+
+  // LINT.IfChange(savedmodel_to_stablehlo)
+  m.def(
+      "savedmodel_to_stablehlo",
+      [](absl::string_view input_path,
+         const std::vector<std::string>& exported_model_signatures =
+             {"serving_default"},
+         const std::vector<std::string>& tag_names = {"serve"},
+         absl::string_view input_arg_shapes_str = "") -> py::bytes {
+        auto module_bytecode =
+            PywrapSavedModelToStablehlo(input_path, exported_model_signatures,
+                                        tag_names, input_arg_shapes_str);
+        if (!module_bytecode.ok()) {
+          PyErr_SetString(PyExc_ValueError,
+                          module_bytecode.status().ToString().c_str());
+          throw py::error_already_set();
+        }
+        return py::bytes(module_bytecode.value());
+      },
+      R"pbdoc(
+      Converts a TensorFlow SavedModel into StableHLO bytecode.
+
+      * input_path: The path to the input TensorFlow SavedModel.
+      * exported_model_signatures: List of exported model signatures to
+        convert.
+      * tag_names: List of tags for loading SavedModel.
+      * input_arg_shapes_str: A string representation of input argument shapes
+        for 'main' entry-point, separating tensors with ':', dimension with
+        ',', and using '?' for unknown sizes. For example,
+        'input-arg-shapes=1,2::1,?' expresses argument shapes [1,2], [] and
+        [1,?].
+      )pbdoc",
+      py::arg("input_path"),
+      py::arg("exported_model_signatures") =
+          std::vector<std::string>{"serving_default"},
+      py::arg("tag_names") = std::vector<std::string>{"serve"},
+      py::arg("input_arg_shapes_str") = "");
+  // LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:savedmodel_to_stablehlo)
+  //
+  // LINT.IfChange(tensorflow_module_to_stablehlo)
+  m.def(
+      "tensorflow_module_to_stablehlo",
+      [](absl::string_view module_op_str,
+         absl::string_view input_arg_shapes_str) -> py::bytes {
+        auto module_bytecode =
+            PywrapTfModuleToStablehlo(module_op_str, input_arg_shapes_str);
+        if (!module_bytecode.ok()) {
+          PyErr_SetString(PyExc_ValueError,
+                          module_bytecode.status().ToString().c_str());
+          throw py::error_already_set();
+        }
+        return py::bytes(module_bytecode.value());
+      },
+      R"pbdoc(
+      Converts a TensorFlow MLIR module string into StableHLO bytecode.
+
+      * module: TensorFlow MLIR module string.
+      * input_arg_shapes_str: A string representation of input argument shapes
+        for 'main' entry-point, separating tensors with ':', dimension with
+        ',', and using '?' for unknown sizes. For example,
+        'input-arg-shapes=1,2::1,?' expresses argument shapes [1,2], [] and
+        [1,?].
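+
+      Returns:
+        The StableHLO MLIR module serialized as bytecode (bytes).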
+ )pbdoc", + py::arg("module"), py::arg("input_arg_shapes_str") = ""); + // LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:tensorflow_module_to_stablehlo) +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.pyi b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.pyi new file mode 100644 index 00000000000000..ec5eaad7983bf0 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.pyi @@ -0,0 +1,30 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# LINT.IfChange(savedmodel_to_stablehlo) +def savedmodel_to_stablehlo( + input_path: str, + exported_model_signatures: list[str] = ["serving_default"], + tag_names: list[str] = ["serve"], + input_arg_shapes_str: str = "", +) -> bytes: ... +# LINT.ThenChange() + +# LINT.IfChange(tensorflow_module_to_stablehlo) +def tensorflow_module_to_stablehlo( + module: str, + input_arg_shapes_str: str = "", +) -> bytes: ... +# LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.cc b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.cc new file mode 100644 index 00000000000000..cbd535a861482f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.cc @@ -0,0 +1,141 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/core/platform/path.h" + +namespace mlir::tensorflow_to_stablehlo::pywrap { + +absl::StatusOr ModuleToBytecode(ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +absl::StatusOr ExportModule(ModuleOp module) { + const std::string output_filename = tensorflow::io::GetTempFilename(".mlir"); + std::string error_msg; + auto output = openOutputFile(output_filename, &error_msg); + if (output == nullptr) { + return absl::UnknownError( + absl::StrCat("Unable to open output path: ", error_msg)); + } + + std::string result; + llvm::raw_string_ostream os(result); + OpPrintingFlags printing_flags; + module.print(os, printing_flags); + + output->os() << result; + output->keep(); + + return output_filename; +} + +absl::StatusOr PywrapSavedModelToStablehlo( + absl::string_view input_path, + const std::vector& exported_model_signatures, + const std::vector& tag_names, + absl::string_view input_arg_shapes_str) { + mlir::DialectRegistry registry; + RegisterAllTensorFlowDialects(registry); + mlir::MLIRContext context(registry); + context.loadAllAvailableDialects(); + + auto module = + TfToStablehlo(input_path, &context, exported_model_signatures, tag_names, + input_arg_shapes_str, /*is_input_mlir_module=*/false); + + if (!module.ok()) { + return absl::UnknownError( + absl::StrCat("Failed to convert SavedModel to StableHLO: ", + module.status().message())); + } + + auto bytecode = ModuleToBytecode(module.value().get()); + if (!bytecode.ok()) { + return absl::UnknownError( + absl::StrCat("Failed to serialize MLIR module to bytecode: ", + bytecode.status().message())); + } + + return bytecode.value(); +} + +absl::StatusOr PywrapTfModuleToStablehlo( + absl::string_view module_op_str, absl::string_view input_arg_shapes_str) { + mlir::DialectRegistry registry; + RegisterAllTensorFlowDialects(registry); + mlir::MLIRContext context(registry); + context.loadAllAvailableDialects(); + + auto tf_module = mlir::parseSourceString(module_op_str, &context); + if (!tf_module) { + return absl::UnknownError("Failed to parse MLIR module"); + } + + auto mlir_file_path = ExportModule(*tf_module); + if (!mlir_file_path.ok()) { + return absl::UnknownError( + absl::StrCat("Failed to write MLIR module to 
file: ",
+                     mlir_file_path.status().message()));
+  }
+
+  auto module = TfToStablehlo(*mlir_file_path, &context,
+                              /*exported_model_signatures=*/{},
+                              /*tag_names=*/{}, input_arg_shapes_str,
+                              /*is_input_mlir_module=*/true);
+
+  if (!module.ok()) {
+    return absl::UnknownError(
+        absl::StrCat("Failed to convert TF module to StableHLO: ",
+                     module.status().message()));
+  }
+
+  auto bytecode = ModuleToBytecode(module.value().get());
+  if (!bytecode.ok()) {
+    return absl::UnknownError(
+        absl::StrCat("Failed to serialize MLIR module to bytecode: ",
+                     bytecode.status().message()));
+  }
+
+  return bytecode.value();
+}
+
+}  // namespace mlir::tensorflow_to_stablehlo::pywrap
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h
new file mode 100644
index 00000000000000..c79ed32b990dd6
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h
@@ -0,0 +1,67 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+
+namespace mlir::tensorflow_to_stablehlo::pywrap {
+
+// Converts a TensorFlow SavedModel to a StableHLO MLIR module and serializes
+// it to bytecode.
+//
+// Args:
+//  input_path: The path to the SavedModel directory.
+//  exported_model_signatures: List of exported model signatures to convert.
+//  tag_names: List of tags for loading the SavedModel.
+//  input_arg_shapes_str: A string representation of input argument shapes for
+//   'main' entry-point, separating tensors with ':', dimension with ',', and
+//   using '?' for unknown sizes. For example, 'input-arg-shapes=1,2::1,?'
+//   expresses argument shapes [1,2], [] and [1,?].
+//
+// Returns:
+//  An absl::StatusOr containing the serialized bytecode of the StableHLO
+//  module on success, or an error status on failure.
+absl::StatusOr<std::string> PywrapSavedModelToStablehlo(
+    absl::string_view input_path,
+    const std::vector<std::string>& exported_model_signatures,
+    const std::vector<std::string>& tag_names,
+    absl::string_view input_arg_shapes_str);
+
+// Converts a TensorFlow MLIR module string to a StableHLO MLIR module and
+// serializes it to bytecode.
+//
+// Args:
+//  module_op_str: TensorFlow MLIR module string.
+//  input_arg_shapes_str: A string representation of input argument shapes for
+//   'main' entry-point, separating tensors with ':', dimension with ',', and
+//   using '?' for unknown sizes. For example, 'input-arg-shapes=1,2::1,?'
expresses argument shapes [1,2], [] and [1,?]. +// +// Returns: +// An absl::StatusOr containing the serialized bytecode of the StableHLO +// module on success, or an error status on failure. +absl::StatusOr PywrapTfModuleToStablehlo( + absl::string_view module_op_str, absl::string_view input_arg_shapes_str); + +} // namespace mlir::tensorflow_to_stablehlo::pywrap + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tests/test_tf_to_stablehlo.mlir b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tests/test_tf_to_stablehlo.mlir new file mode 100644 index 00000000000000..7c71e7014fa743 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tests/test_tf_to_stablehlo.mlir @@ -0,0 +1,22 @@ +// RUN: tf-to-stablehlo-translate %s --input-arg-shapes=1 -o - | FileCheck %s + +// CHECK-LABEL: func.func @main +// CHECK: %[[UQ:.*]] = stablehlo.uniform_quantize %arg0 : (tensor<1xf32>) -> tensor<1x!quant.uniform> +// CHECK: %[[BITCAST_CONVERT_0:.*]] = stablehlo.bitcast_convert %[[UQ]] : (tensor<1x!quant.uniform>) -> tensor<1xi8> +// CHECK: %[[BITCAST_CONVERT_1:.*]] = stablehlo.bitcast_convert %[[BITCAST_CONVERT_0]] : (tensor<1xi8>) -> tensor<1x!quant.uniform> +// CHECK: %[[UDQ:.*]] = stablehlo.uniform_dequantize %[[BITCAST_CONVERT_1]] : (tensor<1x!quant.uniform>) -> tensor<1xf32> +// CHECK: return %[[UDQ]] : tensor<1xf32> +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0 : tensor) -> tensor { + %scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %zps = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + + %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor, tensor, tensor) -> tensor + %1 = "tf.UniformDequantize"(%0, %scales, %zps) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor, tensor, tensor) -> tensor + func.return %1 : tensor + } +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.cc b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.cc new file mode 100644 index 00000000000000..08cf8e67957c28 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.cc @@ -0,0 +1,138 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" +#include "tensorflow/core/public/session.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace mlir { +namespace { + +// Extract the mlir TF module and optionally a ::tensorflow::SavedModelBundle +// from a saved model or from an mlir file. +absl::StatusOr ImportSavedModelOrTfMlir( + absl::string_view input_path, MLIRContext* context, + const std::vector& exported_model_signatures, + const std::vector& tag_names, bool is_input_mlir_module) { + if (is_input_mlir_module) { + std::string error_message; + std::unique_ptr file = + openInputFile(input_path, &error_message); + if (!file) { + return absl::AbortedError( + absl::StrCat("Failed to parse input MLIR model: ", error_message)); + } + + llvm::SourceMgr source_mgr; + source_mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); + auto module = parseSourceFile(source_mgr, context); + if (module->getOperation() == nullptr) { + return absl::AbortedError("Failed to parse input MLIR model."); + } + + return quant::stablehlo::ImportedMlirModuleOp(std::move(module), nullptr); + } + + std::unordered_set tag_set(tag_names.begin(), tag_names.end()); + return quant::stablehlo::SavedModelToMlirModuleOp( + input_path, tag_set, exported_model_signatures, *context); +} + +// Convert an TF module to a StableHLO module +absl::StatusOr> ConvertTFToStablehlo( + quant::stablehlo::ImportedMlirModuleOp imported_module, + absl::string_view input_path, MLIRContext* context, + const std::vector& tag_names, + absl::string_view input_arg_shapes_str, bool is_input_mlir_module) { + auto [module_op, saved_model_bundle] = std::move(imported_module); + + // Collect the names of the functions that have aliases so that they may not + // be inlined. 
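+  // These names are passed to PreprocessAndFreezeGraph() below as
+  // `noinline_functions`, since inlining them would invalidate the aliases.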
+ absl::flat_hash_set aliased_function_names; + if (!is_input_mlir_module) { + std::unordered_set tag_set(tag_names.begin(), tag_names.end()); + TF_ASSIGN_OR_RETURN( + auto function_aliases, + quant::stablehlo::GetFunctionAliases(input_path, tag_set)); + quant::stablehlo::UpdateFunctionAliases(function_aliases, *module_op); + absl::c_for_each(function_aliases, [&](const auto& aliases) { + return aliased_function_names.insert(aliases.first); + }); + } + + std::optional session; + if (saved_model_bundle) { + session = saved_model_bundle->GetSession(); + } + TF_ASSIGN_OR_RETURN(auto input_arg_shapes_vec, + TF::ParseArgumentShapes(input_arg_shapes_str)); + llvm::SmallVector> input_arg_shapes( + input_arg_shapes_vec.begin(), input_arg_shapes_vec.end()); + TF_RETURN_IF_ERROR(tensorflow::quantization::PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/"", /*is_inliner_run=*/true, + /*noinline_functions=*/aliased_function_names, *module_op, context, + session, + /*run_tf_to_stablehlo=*/true, /*deserialize_xla_call_module=*/false, + input_arg_shapes)); + + return std::move(module_op); +} + +} // namespace + +absl::StatusOr> TfToStablehlo( + absl::string_view input_path, MLIRContext* context, + const std::vector& exported_model_signatures, + const std::vector& tag_names, + absl::string_view input_arg_shapes_str, bool is_input_mlir_module) { + auto import_module_status = + ImportSavedModelOrTfMlir(input_path, context, exported_model_signatures, + tag_names, is_input_mlir_module); + if (!import_module_status.ok()) { + return import_module_status.status(); + } + + return ConvertTFToStablehlo(*std::move(import_module_status), input_path, + context, tag_names, input_arg_shapes_str, + is_input_mlir_module); +} + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h new file mode 100644 index 00000000000000..55a579344d6c4d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h @@ -0,0 +1,56 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_TF_TO_STABLEHLO_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_TF_TO_STABLEHLO_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { + +// Converts a TensorFlow model (either from a SavedModel or an MLIR module) to a +// StableHLO MLIR module. +// +// Args: +// input_path: The path to the input TensorFlow SavedModel or MLIR module. +// context: The MLIR context to use for parsing or creating the MLIR module. +// exported_model_signatures: List of exported model signatures (strings) to +// convert. 
+// tag_names: List of tag names (strings) used for loading SavedModel. +// Ignored for MLIR input. +// input_arg_shapes_str: A string representation of input argument shapes for +// 'main' entry-point, separating tensors with ':', dimension with ',', and +// using '?' for unknown sizes. For example, 'input-arg-shapes=1,2::1,?' +// expresses argument shapes [1,2], [] and [1,?]. +// is_input_mlir_module: If true, `input_path` is treated as an MLIR +// module instead of a SavedModel. +// +// Returns: +// An absl::StatusOr containing the converted StableHLO MLIR module on +// success, or an absl::Status with an error message on failure. +absl::StatusOr> TfToStablehlo( + absl::string_view input_path, MLIRContext* context, + const std::vector& exported_model_signatures, + const std::vector& tag_names, + absl::string_view input_arg_shapes_str, bool is_input_mlir_module); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_TF_TO_STABLEHLO_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo_translate.cc b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo_translate.cc new file mode 100644 index 00000000000000..6b43cc8b112313 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo_translate.cc @@ -0,0 +1,134 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" + +namespace { + +using llvm::cl::opt; + +// NOLINTNEXTLINE +opt input_path(llvm::cl::Positional, + llvm::cl::desc(""), llvm::cl::Required); + +// NOLINTNEXTLINE +opt output_filename("o", llvm::cl::desc(""), + llvm::cl::Optional, llvm::cl::init("-")); + +// NOLINTNEXTLINE +opt input_arg_shapes_str( + "input-arg-shapes", + llvm::cl::desc( + "A string representation of input argument shapes for 'main' " + "entry-point, separating tensors with ':', dimension with ',', and " + "using '?' for unknown sizes. For example, 'input-arg-shapes=1,2::1,?' 
" + "expresses argument shapes [1,2], [] and [1,?]"), + llvm::cl::Optional, llvm::cl::init("")); + +// NOLINTNEXTLINE +opt exported_model_signatures( + "exported-model-signatures", + llvm::cl::desc( + "Comma-separated list of exported model signatures to convert"), + llvm::cl::Optional, llvm::cl::init("serving_default")); + +// NOLINTNEXTLINE +opt tag_names( + "tags", + llvm::cl::desc("Comma-separated list of tags for loading SavedModel. " + "Ignored for MLIR input"), + llvm::cl::Optional, llvm::cl::init("serve")); + +// NOLINTNEXTLINE +opt elide_large_elements_attrs( + "e", + llvm::cl::desc( + "Elide large elements attrs while dumping the output StableHLO."), + llvm::cl::Optional, llvm::cl::init(false)); + +} // namespace + +namespace mlir { + +namespace { +// Dump the ModuleOp 'module' to the file specified using 'outputFileName' +absl::Status ExportModule(ModuleOp module) { + std::string error_msg; + auto output = openOutputFile(output_filename, &error_msg); + if (output == nullptr) { + return absl::AbortedError( + absl::StrCat("Unable to write to output path: ", error_msg)); + } + + // Export StableHLO MLIR as output + std::string result; + llvm::raw_string_ostream os(result); + OpPrintingFlags printing_flags; + if (elide_large_elements_attrs) { + printing_flags.elideLargeElementsAttrs(); + } + module.print(os, printing_flags); + os.flush(); + + output->os() << result; + output->keep(); + + return absl::OkStatus(); +} + +} // namespace +} // namespace mlir + +int main(int argc, char** argv) { + tensorflow::InitMlir y(&argc, &argv); + llvm::cl::ParseCommandLineOptions(argc, argv, + "TF Saved Model to Stablehlo converter\n"); + + mlir::DialectRegistry registry; + RegisterAllTensorFlowDialects(registry); + mlir::MLIRContext context(registry); + context.loadAllAvailableDialects(); + + bool is_input_mlir_module = absl::EndsWith(input_path, ".mlir"); + std::vector exported_model_signatures_in_vector = + absl::StrSplit(exported_model_signatures, ','); + std::vector tag_names_in_vector = absl::StrSplit(tag_names, ','); + auto module = TfToStablehlo( + input_path, &context, exported_model_signatures_in_vector, + tag_names_in_vector, input_arg_shapes_str, is_input_mlir_module); + if (!module.ok()) { + llvm::errs() << module.status().ToString() << "\n"; + return module.status().raw_code(); + } + + return mlir::ExportModule(module->get()).raw_code(); +} diff --git a/tensorflow/compiler/mlir/register_common_dialects.cc b/tensorflow/compiler/mlir/register_common_dialects.cc index b089bd9a1eb787..fe626375a8ee8f 100644 --- a/tensorflow/compiler/mlir/register_common_dialects.cc +++ b/tensorflow/compiler/mlir/register_common_dialects.cc @@ -28,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "xla/mlir/framework/ir/xla_framework.h" #include "xla/mlir_hlo/mhlo/IR/register.h" -#include "xla/service/cpu/hlo_xla_runtime_pipeline.h" namespace mlir { @@ -38,7 +37,6 @@ void RegisterCommonToolingDialects(mlir::DialectRegistry& registry) { mlir::registerAllDialects(registry); mlir::registerAllExtensions(registry); mlir::stablehlo::registerAllDialects(registry); - xla::cpu::RegisterHloXlaRuntimePipelineDialects(registry); registry.insert(); registry.insert(); diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index abfcfcfc6746e6..d32bffa13dc11a 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -79,6 +79,7 @@ 'mlir-translate', 'odml-to-stablehlo-opt', 'odml_to_stablehlo', + 'odml-converter', 'stable-quant-opt', 'tac-opt-all-backends', 'tac-translate', diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index dc75547758e11f..b4fc94d8729d68 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -42,6 +42,7 @@ 'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/lite/experimental/tac', 'tensorflow/compiler/mlir/lite/stablehlo', + 'tensorflow/compiler/mlir/lite/stablehlo/odml_converter', 'tensorflow/compiler/mlir/quantization/tensorflow', 'tensorflow/compiler/mlir/tensorflow', 'tensorflow/compiler/mlir/tfrt', diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 26d5e4d52b41d7..b138c2d3efd598 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -644,6 +644,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -675,10 +676,12 @@ cc_library( ":tensorflow", ":tensorflow_op_interfaces", ":tensorflow_side_effects", + ":tensorflow_traits", ":tensorflow_types", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:FuncDialect", @@ -778,6 +781,7 @@ cc_library( hdrs = ["utils/location_utils.h"], deps = [ "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -908,6 +912,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:ml_dtypes", "@local_xla//xla:test", ], @@ -938,6 +943,7 @@ cc_library( "//tensorflow/core/util:managed_stack_trace", "@com_google_absl//absl/status", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_xla//xla/mlir/utils:error_util", ], ) @@ -1401,6 +1407,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -1473,6 +1480,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -1505,6 +1513,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -1516,6 +1525,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", + 
"@llvm-project//mlir:Support", ], ) @@ -1656,9 +1666,6 @@ aliased_targets = [ "export_graphdef", "import_model", "export_tf_dialect_op", - "translate_tf_dialect_op", - "mlir_roundtrip_pass", - "mlir_roundtrip_pass_registration", "mlir_roundtrip_flags", "mlir_import_options", "translate_lib", diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc index 348316e2648ccb..267bc48d17e06d 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc @@ -20,29 +20,36 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "mlir/Analysis/CallGraph.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { namespace TF { @@ -74,7 +81,7 @@ class BacktrackAnalysisInfo { // the result cannot be backtracked to a region argument, returns // std::nullopt. std::optional GetArg(int result_index) const { - if (auto arg = GetValue(result_index).dyn_cast()) + if (auto arg = mlir::dyn_cast(GetValue(result_index))) if (arg.getParentBlock() == ®ion_->front()) return arg.getArgNumber(); return std::nullopt; } @@ -191,7 +198,7 @@ BacktrackAnalysis::BacktrackAnalysis( // possible. 
Value BacktrackAnalysis::BacktrackValue(Value value) { while (Operation* op = value.getDefiningOp()) { - int res_index = value.cast().getResultNumber(); + int res_index = mlir::cast(value).getResultNumber(); if (auto graph = dyn_cast(op)) { value = graph.GetFetch().getOperand(res_index); } else if (auto island = dyn_cast(op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h index 7afec29bc5df75..c49852c1864763 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h @@ -21,14 +21,21 @@ limitations under the License. #include #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc index 5ceda80490f688..e27d0405d7e8f1 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc @@ -46,7 +46,7 @@ ResourceConstructingOps ResourceConstructingOps::EntryState( return ResourceConstructingOps(); } ResourceConstructingOps ResourceConstructingOps::EntryState(Value value) { - if (auto barg = value.dyn_cast()) { + if (auto barg = mlir::dyn_cast(value)) { if (func::FuncOp func = dyn_cast(barg.getOwner()->getParentOp())) { SymbolTable symbol_table(func->getParentOfType()); @@ -87,7 +87,7 @@ IsComposite IsComposite::EntryState(MLIRContext *context) { IsComposite IsComposite::EntryState(Value value) { IsComposite result; - if (auto barg = value.dyn_cast()) { + if (auto barg = mlir::dyn_cast(value)) { if (func::FuncOp func = dyn_cast(barg.getOwner()->getParentOp())) { if (func.getArgAttr(barg.getArgNumber(), kCompositeDevice)) { diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h index 0cf3611af1d20c..1e68ac41d25b54 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h @@ -25,9 +25,11 @@ limitations under the License. 
#include "llvm/Support/Debug.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project #include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc index e1a984ea69bc67..372446641382ac 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc @@ -16,9 +16,18 @@ limitations under the License. #include +#include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -29,8 +38,8 @@ namespace TF { namespace { bool IsResourceType(Type type) { - if (auto tensor_type = type.dyn_cast()) { - return tensor_type.getElementType().isa(); + if (auto tensor_type = mlir::dyn_cast(type)) { + return mlir::isa(tensor_type.getElementType()); } return false; } @@ -44,10 +53,9 @@ func::FuncOp GetSessionInitializerFunc(ModuleOp module) { auto session_init_op = tf_saved_model::GetSessionInitializerOp(module); if (session_init_op && !session_init_op.getInitializers().empty()) { SymbolTable symbol_table(module); - func::FuncOp init_func_op = - symbol_table.lookup(session_init_op.getInitializers()[0] - .cast() - .getValue()); + func::FuncOp init_func_op = symbol_table.lookup( + mlir::cast(session_init_op.getInitializers()[0]) + .getValue()); return init_func_op; } return nullptr; diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h index 9817b290c4cbdb..738d8c1df3d395 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h @@ -22,6 +22,11 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index c95dd020497385..179b3979348161 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -26,25 +26,32 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" +#include "absl/log/log.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -73,7 +80,7 @@ const ResourceIdSet& UnknownResourceSet() { const ResourceIdSet& GetResourceUniqueIdsOrUnknown( Value value, const ResourceAliasAnalysis::Info& alias_analysis) { - if (!getElementTypeOrSelf(value.getType()).isa() || + if (!mlir::isa(getElementTypeOrSelf(value.getType())) || alias_analysis.IsUnknownResource(value)) return UnknownResourceSet(); return alias_analysis.GetResourceUniqueIds(value); } @@ -145,7 +152,7 @@ bool MayHaveSideEffect(Operation* op) { bool ShouldUseResourceAliasAnalysis( const MemoryEffects::EffectInstance& effect) { Value value = effect.getValue(); - if (value && getElementTypeOrSelf(value.getType()).isa()) { + if (value && mlir::isa(getElementTypeOrSelf(value.getType()))) { // For value-based effects on resource values we can use resource alias // analysis. 
return true; diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h index 97fcd30d36d02f..feb90de18857b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -23,12 +23,18 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index e9a35b1221c2a4..7275aee19e49f4 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -121,7 +121,7 @@ class MlirTensor : public TracingTensorHandle { Value getValue() { return value_; } Type getElementType() { - return value_.getType().cast().getElementType(); + return mlir::cast(value_.getType()).getElementType(); } // For LLVM style RTTI. @@ -340,11 +340,11 @@ Status MlirAbstractOp::SetOpName(const char* const op_name) { Status MlirAbstractOp::AddRef(Type type, Type* output_type) { Type elt_type = getElementTypeOrSelf(type); - if (elt_type.isa()) { + if (mlir::isa(elt_type)) { return InvalidArgument("Requested reference to a reference type"); } elt_type = TensorFlowRefType::get(elt_type); - if (RankedTensorType tensor_type = type.dyn_cast()) { + if (RankedTensorType tensor_type = mlir::dyn_cast(type)) { *output_type = RankedTensorType::get(tensor_type.getShape(), elt_type); } *output_type = UnrankedTensorType::get(elt_type); @@ -373,11 +373,11 @@ Status MlirAbstractOp::Create(ArrayRef operands, return InvalidArgument("Missing attribute '", output_arg.number_attr(), "' required for output list '", output_arg.name(), "'"); - if (!repeats_attr.isa()) + if (!mlir::isa(repeats_attr)) return InvalidArgument("Attribute '", output_arg.number_attr(), "' required for output list '", output_arg.name(), "' isn't an integer"); - int64_t repeats = repeats_attr.cast().getInt(); + int64_t repeats = mlir::cast(repeats_attr).getInt(); if (!output_arg.type_attr().empty()) { // Same type repeated "repeats" times. 
@@ -386,7 +386,7 @@ Status MlirAbstractOp::Create(ArrayRef operands, return InvalidArgument("Missing attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), "'"); - TypedAttr type_attr = attr.dyn_cast(); + TypedAttr type_attr = mlir::dyn_cast(attr); if (!type_attr) return InvalidArgument("Attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), @@ -410,7 +410,7 @@ Status MlirAbstractOp::Create(ArrayRef operands, return InvalidArgument("Missing attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), "'"); - TypeAttr type_attr = attr.dyn_cast(); + TypeAttr type_attr = mlir::dyn_cast(attr); if (!type_attr) return InvalidArgument("Attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), @@ -423,13 +423,13 @@ Status MlirAbstractOp::Create(ArrayRef operands, return InvalidArgument( "Missing attribute '", output_arg.type_list_attr(), "' required for output '", output_arg.name(), "'"); - ArrayAttr array_attr = attr.dyn_cast(); + ArrayAttr array_attr = mlir::dyn_cast(attr); if (!array_attr) return InvalidArgument("Attribute '", output_arg.type_list_attr(), "' required for output '", output_arg.name(), "' isn't an array attribute"); for (Attribute attr : array_attr) { - TypeAttr type_attr = attr.dyn_cast(); + TypeAttr type_attr = mlir::dyn_cast(attr); if (!type_attr) return InvalidArgument("Array Attribute '", output_arg.type_list_attr(), diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.cc index dba58f17ccb029..9a1db50ff6b732 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -45,11 +46,12 @@ _TfrtGetResourceOp::GetResourceHandleValueAndIdList( for (const auto &iter : llvm::enumerate(getResults())) { auto index = iter.index(); - if (getElementTypeOrSelf(iter.value().getType()).isa()) { + if (mlir::isa( + getElementTypeOrSelf(iter.value().getType()))) { resource_vec.push_back(GetResourceHandleValueAndIdBase( - getContainer()[index].cast().getValue(), - getSharedName()[index].cast().getValue(), device, - getResults()[index], resource_handle_id_map, next_id)); + mlir::cast(getContainer()[index]).getValue(), + mlir::cast(getSharedName()[index]).getValue(), + device, getResults()[index], resource_handle_id_map, next_id)); } } return resource_vec; @@ -100,16 +102,16 @@ mlir::LogicalResult IfrtCallOp::verify() { } for (mlir::Value arg : getArgs()) { - if (mlir::getElementTypeOrSelf(arg.getType()) - .isa()) { + if (mlir::isa( + mlir::getElementTypeOrSelf(arg.getType()))) { return emitOpError() << "does not support passing '!tf.resource' values as arguments"; } } for (mlir::Value result : getResults()) { - if (mlir::getElementTypeOrSelf(result.getType()) - .isa()) { + if (mlir::isa( + mlir::getElementTypeOrSelf(result.getType()))) { return emitOpError() << "does not support returning '!tf.resource' values as results"; } @@ -118,12 +120,13 @@ mlir::LogicalResult IfrtCallOp::verify() { // Verify variable_arg_indices is sorted in ascending order. int64_t prev_index = -1; for (auto arg_index_attr : getVariableArgIndicesAttr()) { - if (!arg_index_attr.isa_and_nonnull()) { + if (!mlir::isa_and_nonnull(arg_index_attr)) { return emitOpError() << "variable_arg_indices must be an integer"; } - int64_t index = - arg_index_attr.dyn_cast().getValue().getSExtValue(); + int64_t index = mlir::dyn_cast(arg_index_attr) + .getValue() + .getSExtValue(); if (index < 0) { return emitOpError() << "variable_arg_indices must be positive"; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td index 0c783c01caa287..e46a6500dfd516 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td @@ -94,29 +94,26 @@ Empty strings indicate that they are non-partitioned tensors.}]>:$shape_and_slic def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", [Pure]> { - let summary = "Loads a variable tensor as an IFRT array"; + let summary = "Loads a restored variable tensor as a tensor future"; let description = [{ - This op loads a variable tensor as an IFRT array and binds it with the specified name. + This op loads a restored variable tensor as a tensor future. It is a + replacement of `tf.ReadVariableOp`. - This op is an replacement of `tf.ReadVariableOp` in the case that a constant - variable tensor is an input to the tpu program invoked by `tf.IfrtCall`. + This op returns a scalar string tensor containing the restored variable name, which can be + used as a key within the runtime, as well as a future for the tensor. 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td
index 0c783c01caa287..e46a6500dfd516 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td
@@ -94,29 +94,26 @@ Empty strings indicate that they are non-partitioned tensors.}]>:$shape_and_slic

 def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", [Pure]> {
-  let summary = "Loads a variable tensor as an IFRT array";
+  let summary = "Loads a restored variable tensor as a tensor future";

   let description = [{
-    This op loads a variable tensor as an IFRT array and binds it with the specified name.
+    This op loads a restored variable tensor as a tensor future. It is a
+    replacement of `tf.ReadVariableOp`.

-    This op is an replacement of `tf.ReadVariableOp` in the case that a constant
-    variable tensor is an input to the tpu program invoked by `tf.IfrtCall`.
+    This op returns a scalar string tensor containing the restored variable
+    name, which can be used as a key within the runtime, and a future for the
+    tensor.

-    After a `tf.ReadVariableOp` is lowered into `tf.IfrtLoadVariableOp`, the `tf.IfrtCall` kernel
-    will bind the loaded IFRT array by name with the tpu program's input.
-
-    `tf.IfrtLoadVariableOp` converts the tensor into an IFRT array based on device and sharding
-    configuration specified in `VariableDeviceShardingConfigProto`.
-
-    This op returns a scalar string tensor containing the loaded variable name, which can be
-    used as a key to look for the loaded IFRT array in runtime and a restored tensor, which
-    maybe lowered to a future by runtime.
+    The `tf.IfrtCall` kernel uses the output $array_key.
+    Other ops executed by TFRT may make use of $tensor_future.
   }];

+  // TODO(b/339423851) Redefine the IfrtLoadVariableOp, as it doesn't require
+  // the sharding info in the attribute if multihost does not need this info.
   let arguments = (ins
     Arg:$variable,
     DefaultValuedStrAttr:$device_sharding_config_proto_text,
-    DefaultValuedAttr:$name
+    DefaultValuedAttr:$name,
+    DefaultValuedAttr:$used_by_host
   );

   let results = (outs
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.cc
index f5284a0ef3cf96..9a78a1a83ae214 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

 namespace mlir {
@@ -27,12 +28,12 @@ namespace TF {
 // Verifies an reduction op's `input` and reduction `dims`.
 LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
                                           Location loc) {
-  auto dims_type = dims.getType().dyn_cast<RankedTensorType>();
+  auto dims_type = mlir::dyn_cast<RankedTensorType>(dims.getType());
   if (!dims_type) return success();
   if (dims_type.getRank() > 1)
     return emitError(loc, "dimensions can only be 0D or 1D tensor");

-  auto input_type = input.getType().dyn_cast<RankedTensorType>();
+  auto input_type = mlir::dyn_cast<RankedTensorType>(input.getType());
   if (!input_type) return success();

   int64_t rank = input_type.getRank();
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h
index aa0f84eb122e2b..64b5d2e141f13d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h
@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/TypeRange.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project

 namespace mlir {

@@ -60,10 +61,10 @@ template <
     typename OpT,
     typename std::enable_if<llvm::is_one_of<
         OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
 OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
                                         ArrayRef<Attribute> operands) {
-  auto lhs_type = arithmetic_op.getX().getType().template cast<ShapedType>();
-  auto rhs_type = arithmetic_op.getY().getType().template cast<ShapedType>();
+  auto lhs_type = mlir::cast<ShapedType>(arithmetic_op.getX().getType());
+  auto rhs_type = mlir::cast<ShapedType>(arithmetic_op.getY().getType());
   auto result_type =
-      arithmetic_op.getResult().getType().template cast<ShapedType>();
+      mlir::cast<ShapedType>(arithmetic_op.getResult().getType());

   // We can fold arithmetic operation only of we can prove that we will not
   // accidentally hide a broadcasting error.
@@ -86,8 +87,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
   // Check that we have a constant operand on one side (candidate for identity).
   const bool is_commutative =
      (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
-  auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
-  auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  auto lhs_attr = mlir::dyn_cast_or_null<DenseElementsAttr>(operands[0]);
+  auto rhs_attr = mlir::dyn_cast_or_null<DenseElementsAttr>(operands[1]);
   if (!rhs_attr && !(is_commutative && lhs_attr)) return {};

   // Mul and Div ops have identity value one while AddV2 and SubOp have identity
@@ -100,9 +101,9 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
   Type element_ty = lhs_type.getElementType();
   Attribute identity_attr;
-  if (auto ty = element_ty.template dyn_cast<FloatType>()) {
+  if (auto ty = mlir::dyn_cast<FloatType>(element_ty)) {
     identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
-  } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
+  } else if (auto ty = mlir::dyn_cast<IntegerType>(element_ty)) {
     identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
   } else {
     return {};
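IdentityArithmeticOpFolder only erases x + 0 / x * 1 when operand and result types match exactly; otherwise a broadcast could silently change the result shape. A self-contained illustration of the hazard that type check guards against (plain C++; all names ours):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified NumPy-style broadcast add of 1-D tensors where one side may be
// size 1.
std::vector<float> BroadcastAdd(const std::vector<float>& x,
                                const std::vector<float>& zeros) {
  size_t n = std::max(x.size(), zeros.size());
  std::vector<float> out(n);
  for (size_t i = 0; i < n; ++i)
    out[i] = x[x.size() == 1 ? 0 : i] + zeros[zeros.size() == 1 ? 0 : i];
  return out;
}

int main() {
  // x has one element, the zero constant has three: the sum has three
  // elements, so "fold to x" would be wrong even though only zeros were added.
  assert(BroadcastAdd({7.f}, {0.f, 0.f, 0.f}).size() == 3);
}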
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
index 5d145c85a68a06..df887ce453b8ea 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
@@ -99,7 +99,8 @@ struct TFInlinerInterface : public DialectInlinerInterface {
   Operation* materializeCallConversion(OpBuilder& builder, Value input,
                                        Type result_type,
                                        Location conversion_loc) const final {
-    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
+    if (!mlir::isa<TensorType>(result_type) ||
+        !mlir::isa<TensorType>(input.getType()))
       return nullptr;
     return builder.create<TF::CastOp>(conversion_loc, result_type, input,
                                       /*truncate=*/builder.getBoolAttr(false));
@@ -307,7 +308,7 @@ ParseResult SetReplicateOpOperands(
     llvm::ArrayRef<Type> region_arg_types, int32_t* n) {
   for (const auto& attr : state->attributes)
     if (attr.getName().strref() == "n")
-      if (auto n_attr = attr.getValue().dyn_cast<IntegerAttr>())
+      if (auto n_attr = mlir::dyn_cast<IntegerAttr>(attr.getValue()))
         *n = n_attr.getInt();

   if (*n < 2)
@@ -507,13 +508,14 @@ LogicalResult ReplicateOp::verify() {
   // Check number of devices, if set, matches `n`.
   if (op.getDevices().has_value()) {
     for (auto device_attr : op.getDevices().value().getValue()) {
-      auto device_list = device_attr.getValue().dyn_cast_or_null<ArrayAttr>();
+      auto device_list =
+          mlir::dyn_cast_or_null<ArrayAttr>(device_attr.getValue());
       if (!device_list)
         return op.emitError()
                << "expects 'devices' to be a map alias and device name list.";

       bool is_device_string = llvm::all_of(device_list, [](Attribute attr) {
-        return attr.dyn_cast_or_null<StringAttr>();
+        return mlir::dyn_cast_or_null<StringAttr>(attr);
       });
       if (!is_device_string)
         return op.emitOpError() << "expects 'devices' to be a consists of "
@@ -747,8 +749,8 @@ static LogicalResult EliminatePassThroughResults(ClusterOp op,
       // Old bridge only removes unsupported TPU types (only string for now)
       // during outside compilation extraction so this should be enough for
      // the parity.
-      bool is_unsupported_type = getElementTypeOrSelf(operand.get().getType())
-                                     .isa<TF::StringType>();
+      bool is_unsupported_type = mlir::isa<TF::StringType>(
+          getElementTypeOrSelf(operand.get().getType()));
       Value result = operand.get();
       if (is_unsupported_type && result.getParentBlock() != &body &&
           !is_used_for_resource_write) {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index f7c35420c22b4a..f48e1570933cf9 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -40,6 +40,7 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
@@ -119,11 +120,11 @@ Type TensorFlowExecutorDialect::parseType(DialectAsmParser &parser) const {

 void TensorFlowExecutorDialect::printType(Type type,
                                           DialectAsmPrinter &os) const {
-  if (type.isa<ControlType>()) {
+  if (mlir::isa<ControlType>(type)) {
     os << "control";
     return;
   }
-  if (type.isa<TokenType>()) {
+  if (mlir::isa<TokenType>(type)) {
     os << "token";
     return;
   }
@@ -141,7 +142,7 @@ namespace {
 LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
   bool found_control = false;
   for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) {
-    if (op->getOperand(operand_idx).getType().isa<ControlType>()) {
+    if (mlir::isa<ControlType>(op->getOperand(operand_idx).getType())) {
       found_control = true;
       continue;
     }
@@ -192,7 +193,7 @@ LogicalResult GraphOp::verify() {
     Value operand = fetch.getOperand(i);
     // Break out of the loop at the first control operand encountered.
     const int64_t num_results = graph.getNumResults();
-    if (operand.getType().isa<ControlType>()) {
+    if (mlir::isa<ControlType>(operand.getType())) {
       if (i != num_results)
         return fetch.emitOpError()
                << "operand #" << i
@@ -241,7 +242,7 @@ ParseResult GraphOp::parse(OpAsmParser &parser, OperationState &result) {
   // the fetch operation.
   result.types.reserve(fetch.getNumOperands());
   for (Type type : fetch.getOperandTypes()) {
-    if (type.isa<ControlType>()) break;
+    if (mlir::isa<ControlType>(type)) break;
     result.types.push_back(type);
   }
@@ -403,8 +404,8 @@ ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) {
   // fully qualified) or a short form with a single type (in which case the data
   // input and the outputs are all using this type and predicate is tensor
   // type).
-  if (types.front().isa<FunctionType>()) {
-    FunctionType type = types.front().cast<FunctionType>();
+  if (mlir::isa<FunctionType>(types.front())) {
+    FunctionType type = mlir::cast<FunctionType>(types.front());
     if (type.getNumInputs() < 2)
       return parser.emitError(parser.getNameLoc())
              << " expects a single data type and a predicate";
@@ -439,7 +440,7 @@ void SwitchOp::print(OpAsmPrinter &p) {
   p << " : ";
   if (getTrueOutput().getType() != data_operand_ty ||
       getFalseOutput().getType() != data_operand_ty ||
-      getPredicate().getType().isa<UnrankedTensorType>()) {
+      mlir::isa<UnrankedTensorType>(getPredicate().getType())) {
     p.printFunctionalType(getOperation());
   } else {
     p << getType(0);
@@ -465,16 +466,16 @@ LogicalResult SwitchNOp::verify() {
   // Check that operand can be broadcasted to each output type.
   auto operand0_type = switchn.getOperand(0).getType();
-  TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
+  TensorType operand0_tensor_type = mlir::dyn_cast<TensorType>(operand0_type);
   if (!operand0_tensor_type) {
     return switchn.emitOpError()
            << "expects data operand to have tensor type but got "
            << operand0_type;
   }
   for (Type output_type : switchn.getResultTypes()) {
-    if (output_type.isa<ControlType>()) break;
+    if (mlir::isa<ControlType>(output_type)) break;

-    TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
+    TensorType output_tensor_type = mlir::dyn_cast<TensorType>(output_type);
     if (!output_tensor_type) {
       return switchn.emitOpError()
              << "expects outputs to have tensor type but got " << output_type;
     }
@@ -483,10 +484,10 @@ LogicalResult SwitchNOp::verify() {
     // If the output type is a ref type, then the operand type should also be of
     // the same ref type. However, if the output type is a non-ref type T, then
     // the operand can be tensor of type T or T_REF.
-    bool is_output_ref =
-        output_tensor_type.getElementType().isa<tf_type::TensorFlowRefType>();
-    if (is_output_ref && !operand0_tensor_type.getElementType()
-                              .isa<tf_type::TensorFlowRefType>()) {
+    bool is_output_ref = mlir::isa<tf_type::TensorFlowRefType>(
+        output_tensor_type.getElementType());
+    if (is_output_ref && !mlir::isa<tf_type::TensorFlowRefType>(
+                             operand0_tensor_type.getElementType())) {
       return switchn.emitOpError()
              << "expects same operand and output element type but got "
              << operand0_tensor_type << " vs " << output_tensor_type;
@@ -573,24 +574,24 @@ LogicalResult MergeOp::verify() {
     return merge.emitOpError() << "expects at least one operand";

   Type data_type = merge.getOperand(0).getType();
-  if (data_type.isa<ControlType>())
+  if (mlir::isa<ControlType>(data_type))
     return merge.emitOpError() << "expects a non-control input";

   // Check that each operand can be individually broadcasted to the output type.
   Type output_type = merge.getOutput().getType();
-  TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
+  TensorType output_tensor_ty = mlir::dyn_cast<TensorType>(output_type);
   if (!output_tensor_ty) {
     return merge.emitOpError()
            << "expects output to have tensor type but got " << output_type;
   }
   bool is_output_ref =
-      output_tensor_ty.getElementType().isa<tf_type::TensorFlowRefType>();
+      mlir::isa<tf_type::TensorFlowRefType>(output_tensor_ty.getElementType());
   for (Type operand_type : merge.getOperandTypes()) {
-    if (operand_type.isa<ControlType>()) break;
+    if (mlir::isa<ControlType>(operand_type)) break;

     // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this
     // constraint.
-    TensorType operand_tensor_ty = operand_type.dyn_cast<TensorType>();
+    TensorType operand_tensor_ty = mlir::dyn_cast<TensorType>(operand_type);
     if (!operand_tensor_ty)
       return merge.emitOpError()
              << "expects data operands to have tensor type but got "
@@ -599,8 +600,8 @@ LogicalResult MergeOp::verify() {
   // If output type is a ref type then all operand types should also be of the
   // same ref type. However, if the output type is a non-ref type T, operands
   // can be tensor of type T or T_REF.
-  if (is_output_ref &&
-      !operand_tensor_ty.getElementType().isa<tf_type::TensorFlowRefType>()) {
+  if (is_output_ref && !mlir::isa<tf_type::TensorFlowRefType>(
+                           operand_tensor_ty.getElementType())) {
     return merge.emitOpError()
            << "expects same operand and output element type but got "
            << operand_tensor_ty << " vs " << output_tensor_ty;
@@ -624,7 +625,7 @@ void MergeOp::print(OpAsmPrinter &p) {
   Type output_type = getOutput().getType();
   for (Type operand_type : getOperandTypes()) {
-    if (operand_type.isa<ControlType>()) break;
+    if (mlir::isa<ControlType>(operand_type)) break;
     num_data_operands++;

     if (operand_type != output_type) {
@@ -660,7 +661,7 @@ ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
   // Support parsing either a functional type (in which case all the types are
   // fully qualified) or a short form with a single type (in which case the data
   // inputs and the output are all using this type).
-  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
+  if (FunctionType type = mlir::dyn_cast<FunctionType>(types.front())) {
     result.types.assign(type.getResults().begin(), type.getResults().end());
     types.assign(type.getInputs().begin(), type.getInputs().end());
   } else {
@@ -747,7 +748,7 @@ ParseResult EnterOp::parse(OpAsmParser &parser, OperationState &result) {
   // Support parsing either a functional type (in which case all the types are
   // fully qualified) or a short form with a single type (in which case the data
   // input and the outputs are all using this type).
-  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
+  if (FunctionType type = mlir::dyn_cast<FunctionType>(types.front())) {
     // One data input, and any number of control inputs.
     if (type.getNumInputs() >= 1) {
       result.types.assign(type.getResults().begin(), type.getResults().end());
@@ -876,7 +877,7 @@ ParseResult LoopCondOp::parse(OpAsmParser &parser, OperationState &result) {
   // fully qualified) or a short form with a single type (in which case the data
   // input and the outputs are all using this type).
   Type control_type = ControlType::get(parser.getBuilder().getContext());
-  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
+  if (FunctionType type = mlir::dyn_cast<FunctionType>(types.front())) {
     if (llvm::count_if(type.getInputs(),
                        [=](Type type) { return type != control_type; }) != 1)
       return parser.emitError(parser.getNameLoc())
@@ -959,14 +960,14 @@ struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern {
     llvm::SmallVector<Value, 8> new_rets;
     for (Value operand : fetch_op.getFetches()) {
       // Control results should not be propagated out.
-      if (operand.getType().isa<ControlType>()) break;
+      if (mlir::isa<ControlType>(operand.getType())) break;

       if (operand.getDefiningOp() != island_op) {
         // Operand is not from island, simply propagate it out.
         new_rets.push_back(operand);
       } else {
         // Lookup yield operand in island for inner op result.
-        auto result = operand.cast<OpResult>();
+        auto result = mlir::cast<OpResult>(operand);
         new_rets.push_back(yield_op.getOperand(result.getResultNumber()));
       }
     }
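Several tf_executor verifiers and parsers above hinge on one structural rule: data operands and results come first, and everything from the first control value onward is control-only. A minimal restatement of VerifyControlOperandsAfterAllData in plain C++ (the enum and names are ours):

#include <vector>

enum class Kind { kData, kControl };

// Once a control operand has been seen, no data operand may follow.
bool ControlOperandsTrailData(const std::vector<Kind>& operands) {
  bool seen_control = false;
  for (Kind k : operands) {
    if (k == Kind::kControl) {
      seen_control = true;
    } else if (seen_control) {
      return false;  // data after control: invalid
    }
  }
  return true;
}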
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 6ba660297366ba..b3d9200aa5d00d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -6274,6 +6274,10 @@ Note that on CPU, if an out of bound index is found, an error is returned. On
 GPU, if an out of bound index is found, a 0 is stored in the corresponding
 output value.

+Note that on TPU, if any dimension of `params` is of size 0, the output will
+have the expected shape, filled with zeros. On CPU and GPU an error will be
+returned.
+
 See also `tf.batch_gather` and `tf.gather_nd`.
 }];
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index d3026b02878741..373586ae837a3f 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -191,7 +191,8 @@ struct TFInlinerInterface : public DialectInlinerInterface {
   Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                        Type result_type,
                                        Location conversion_loc) const final {
-    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
+    if (!mlir::isa<TensorType>(result_type) ||
+        !mlir::isa<TensorType>(input.getType()))
       return nullptr;
     return builder.create<TF::CastOp>(conversion_loc, result_type, input,
                                       /*truncate=*/builder.getBoolAttr(false));
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index a763b50ccd92cf..f8fcf569c9837a 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -2431,4 +2431,300 @@ def TF_StoreMinibatchStatisticsInFdoOp : TF_Op<"StoreMinibatchStatisticsInFdo",
   let results = (outs
   );
 }
+
+def TF_ConvertToListOfSparseCoreCooTensorsOp : TF_Op<"ConvertToListOfSparseCoreCooTensors", [Pure, SameVariadicOperandSize, SameVariadicResultSize]> {
+  let summary = "An op which converts a sparse/ragged/dense tensor into a list of COO tensors, one per SparseCore.";
+
+  let arguments = (ins
+    TF_Int32Tensor:$indices_or_row_splits,
+    TF_Int32Tensor:$values,
+    TF_Float32Tensor:$weights,
+
+    ConfinedAttr]>:$sample_count,
+    ConfinedAttr]>:$row_offset,
+    ConfinedAttr]>:$col_offset,
+    ConfinedAttr]>:$col_shift,
+    ConfinedAttr]>:$num_sc_shards,
+    ConfinedAttr]>:$stacked_table_sample_count,
+    StrAttr:$combiner
+  );
+
+  let results = (outs
+    Variadic<TF_Int32Tensor>:$row_ids_list,
+    Variadic<TF_Int32Tensor>:$col_ids_list,
+    Variadic<TF_Float32Tensor>:$gains_list
+  );
+
+  TF_DerivedResultSizeAttr num_sc_per_chip = TF_DerivedResultSizeAttr<0>;
+}
+
+
+def TF_SortListOfSparseCoreCooTensorsOp : TF_Op<"SortListOfSparseCoreCooTensors", [Pure, SameVariadicOperandSize]> {
+  let summary = "An op which sorts each COO tensor in the list by the SparseCore its ids will go to. This op should be used along with the ConvertToSparseCoreCsrWrappedCooTensorOp.";
+
+  let arguments = (ins
+    Variadic<TF_Int32Tensor>:$row_ids_list,
+    Variadic<TF_Int32Tensor>:$col_ids_list,
+    Variadic<TF_Float32Tensor>:$gains_list,
+
+    I64ArrayAttr:$sample_count_list,
+    I64ArrayAttr:$col_offset_list,
+    ConfinedAttr]>:$num_replica,
+    ConfinedAttr]>:$table_vocab_size,
+    ConfinedAttr]>:$feature_width,
+    ConfinedAttr]>:$num_sc_per_chip,
+    ConfinedAttr]>:$max_ids_per_sparse_core,
+    ConfinedAttr]>:$max_unique_ids_per_sparse_core,
+    StrAttr:$table_name
+  );
+
+  let results = (outs
+    TF_Int32Tensor:$sorted_row_ids,
+    TF_Int32Tensor:$sorted_col_ids,
+    TF_Float32Tensor:$sorted_gains,
+    TF_Int32Tensor:$id_counts
+  );
+
+  // N represents the number of COO tensors in the list.
+  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>;
+}
+
+
+def TF_ConvertToSparseCoreCsrWrappedCooTensorOp : TF_Op<"ConvertToSparseCoreCsrWrappedCooTensorOp", [Pure, SameVariadicOperandSize]> {
+  let summary = "An op which converts the sorted COO tensors into the SparseCore CSR-wrapped COO format.";
+
+  let arguments = (ins
+    Variadic<TF_Int32Tensor>:$sorted_row_ids_list,
+    Variadic<TF_Int32Tensor>:$sorted_col_ids_list,
+    Variadic<TF_Float32Tensor>:$sorted_gains_list,
+    Variadic<TF_Int32Tensor>:$id_counts_list,
+    TF_Int64Tensor:$splits,
+
+    ConfinedAttr]>:$sample_count_per_sc,
+    ConfinedAttr]>:$num_replica,
+    ConfinedAttr]>:$max_minibatches_per_sc,
+    ConfinedAttr]>:$max_ids_per_chip_per_sample,
+    ConfinedAttr]>:$table_vocab_size,
+    ConfinedAttr]>:$feature_width,
+    StrAttr:$table_name,
+    BoolAttr:$allow_id_dropping
+  );
+
+  let results = (outs
+    TF_Int32Tensor:$row_pointers,
+    TF_Int32Tensor:$sorted_sample_ids,
+    TF_Int32Tensor:$sorted_token_ids,
+    TF_Float32Tensor:$sorted_gains,
+    TF_Int32Tensor:$row_pointers_unpadded_size,
+    TF_Int32Tensor:$ids_unpadded_size,
+    TF_Int32Tensor:$num_minibatches_per_sc
+  );
+
+  TF_DerivedOperandSizeAttr num_sc_per_chip = TF_DerivedOperandSizeAttr<1>;
+}
+
+
+def TF_GetStatsFromListOfSparseCoreCooTensorsOp : TF_Op<"GetStatsFromListOfSparseCoreCooTensors", [Pure, SameVariadicOperandSize]> {
+  let summary = "An op which computes the max_ids/max_uniques for a given table.";
+
+  let arguments = (ins
+    Variadic<TF_Int32Tensor>:$row_ids_list,
+    Variadic<TF_Int32Tensor>:$col_ids_list,
+    Variadic<TF_Float32Tensor>:$gains_list,
+
+    I64ArrayAttr:$sample_count_list,
+    I64ArrayAttr:$col_offset_list,
+    ConfinedAttr]>:$num_replica,
+    ConfinedAttr]>:$table_vocab_size,
+    ConfinedAttr]>:$feature_width,
+    ConfinedAttr]>:$num_sc_per_chip,
+    StrAttr:$table_name
+  );
+
+  let results = (outs
+    TF_Int32Tensor:$max_ids_per_sparse_core,
+    TF_Int32Tensor:$max_unique_ids_per_sparse_core
+  );
+
+  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>;
+}
+
+def TF_XlaSparseDenseMatmulWithStaticBufferSizeOp : TF_Op<"XlaSparseDenseMatmulWithStaticBufferSize", [Pure]> {
+  let summary = "An XLA op which performs the dense-sparse matrix multiplication.";
+
+  let arguments = (ins
+    TF_Int32Tensor:$row_pointers,
+    TF_Int32Tensor:$sorted_sample_ids,
+    TF_Int32Tensor:$sorted_token_ids,
+    TF_Float32Tensor:$sorted_gains,
+    TF_Float32Tensor:$embedding_table,
+    TF_Int32Tensor:$num_minibatches_per_physical_sparse_core,
+
+    ConfinedAttr]>:$input_size,
+    OptionalAttr:$quantization_config_low,
+    OptionalAttr:$quantization_config_high,
+    OptionalAttr:$quantization_config_num_buckets,
+    ConfinedAttr]>:$max_ids_per_sparse_core,
+    ConfinedAttr]>:$max_unique_ids_per_sparse_core,
+    StrAttr:$table_name
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$activations
+  );
+}
+
+
+def TF_XlaSparseDenseMatmulGradWithSgdAndStaticBufferSizeOp : TF_Op<"XlaSparseDenseMatmulGradWithSgdAndStaticBufferSize", [Pure]> {
+  let summary = "An XLA op which performs the SGD optimizer update for the dense-sparse matrix multiplication.";
+
+  let arguments = (ins
+    TF_Int32Tensor:$row_pointers,
+    TF_Int32Tensor:$sorted_sample_ids,
+    TF_Int32Tensor:$sorted_token_ids,
+    TF_Float32Tensor:$sorted_gains,
+    TF_Float32Tensor:$activation_gradients,
+    TF_Float32Tensor:$learning_rate,
+    TF_Float32Tensor:$embedding_table,
+    TF_Int32Tensor:$num_minibatches_per_physical_sparse_core,
+
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
+    ConfinedAttr]>:$max_ids_per_sparse_core,
+    ConfinedAttr]>:$max_unique_ids_per_sparse_core,
+    StrAttr:$table_name
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$updated_embedding_table
+  );
+}
+
+def TF_XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSizeOp : TF_Op<"XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSize", [Pure]> {
+  let summary = "An XLA op which performs the Adagrad optimizer update for the dense-sparse matrix multiplication.";
+
+  let arguments = (ins
+    TF_Int32Tensor:$row_pointers,
+    TF_Int32Tensor:$sorted_sample_ids,
+    TF_Int32Tensor:$sorted_token_ids,
+    TF_Float32Tensor:$sorted_gains,
+    TF_Float32Tensor:$activation_gradients,
+    TF_Float32Tensor:$learning_rate,
+    TF_Float32Tensor:$embedding_table,
+    TF_Float32Tensor:$accumulator,
+    TF_Int32Tensor:$num_minibatches_per_physical_sparse_core,
+
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
+    ConfinedAttr]>:$max_ids_per_sparse_core,
+    ConfinedAttr]>:$max_unique_ids_per_sparse_core,
+    StrAttr:$table_name
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$updated_embedding_table,
+    TF_Float32Tensor:$updated_accumulator
+  );
+}
+
+def TF_XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSizeOp : TF_Op<"XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSize", [Pure]> {
+  let summary = "An XLA op which performs the Adagrad momentum optimizer update for the dense-sparse matrix multiplication.";
+
+  let arguments = (ins
+    TF_Int32Tensor:$row_pointers,
+    TF_Int32Tensor:$sorted_sample_ids,
+    TF_Int32Tensor:$sorted_token_ids,
+    TF_Float32Tensor:$sorted_gains,
+    TF_Float32Tensor:$activation_gradients,
+    TF_Float32Tensor:$learning_rate,
+    TF_Float32Tensor:$embedding_table,
+    TF_Float32Tensor:$accumulator,
+    TF_Float32Tensor:$momenta,
+    TF_Int32Tensor:$num_minibatches_per_physical_sparse_core,
+
+    BoolAttr:$use_nesterov,
+    F32Attr:$exponent,
+    F32Attr:$beta1,
+    F32Attr:$beta2,
+    F32Attr:$epsilon,
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
+    ConfinedAttr]>:$max_ids_per_sparse_core,
+    ConfinedAttr]>:$max_unique_ids_per_sparse_core,
+    StrAttr:$table_name
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$updated_embedding_table,
+    TF_Float32Tensor:$updated_accumulator,
+    TF_Float32Tensor:$updated_momenta
+  );
+}
+
+def TF_XlaSparseDenseMatmulGradWithAdamAndStaticBufferSizeOp : TF_Op<"XlaSparseDenseMatmulGradWithAdamAndStaticBufferSize", [Pure]> {
+  let summary = "An XLA op which performs the Adam optimizer update for the dense-sparse matrix multiplication.";
+
+  let arguments = (ins
+    TF_Int32Tensor:$row_pointers,
+    TF_Int32Tensor:$sorted_sample_ids,
+    TF_Int32Tensor:$sorted_token_ids,
+    TF_Float32Tensor:$sorted_gains,
+    TF_Float32Tensor:$activation_gradients,
+    TF_Float32Tensor:$learning_rate,
+    TF_Float32Tensor:$embedding_table,
+    TF_Float32Tensor:$momenta,
+    TF_Float32Tensor:$velocity,
+    TF_Int32Tensor:$num_minibatches_per_physical_sparse_core,
+
+    BoolAttr:$use_sum_inside_sqrt,
+    F32Attr:$beta1,
+    F32Attr:$beta2,
+    F32Attr:$epsilon,
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
+    ConfinedAttr]>:$max_ids_per_sparse_core,
+    ConfinedAttr]>:$max_unique_ids_per_sparse_core,
+    StrAttr:$table_name
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$updated_embedding_table,
+    TF_Float32Tensor:$updated_momenta,
+    TF_Float32Tensor:$updated_velocity
+  );
+}
+
+def TF_XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSizeOp : TF_Op<"XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSize", [Pure]> {
+  let summary = "An XLA op which performs the Ftrl optimizer update for the dense-sparse matrix multiplication.";
+
+  let arguments = (ins
+    TF_Int32Tensor:$row_pointers,
+    TF_Int32Tensor:$sorted_sample_ids,
+    TF_Int32Tensor:$sorted_token_ids,
+    TF_Float32Tensor:$sorted_gains,
+    TF_Float32Tensor:$activation_gradients,
+    TF_Float32Tensor:$learning_rate,
+    TF_Float32Tensor:$embedding_table,
+    TF_Float32Tensor:$accumulator,
+    TF_Float32Tensor:$linear,
+    TF_Int32Tensor:$num_minibatches_per_physical_sparse_core,
+
+    BoolAttr:$multiply_linear_by_learning_rate,
+    F32Attr:$beta,
+    F32Attr:$learning_rate_power,
+    F32Attr:$l1_regularization_strength,
+    F32Attr:$l2_regularization_strength,
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
+    ConfinedAttr]>:$max_ids_per_sparse_core,
+    ConfinedAttr]>:$max_unique_ids_per_sparse_core,
+    StrAttr:$table_name
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$updated_embedding_table,
+    TF_Float32Tensor:$updated_accumulator,
+    TF_Float32Tensor:$updated_linear
+  );
+}
 #endif // TF_OPS
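The new ops above form one pipeline: embedding inputs are converted into per-SparseCore COO lists, sorted and measured, CSR-wrapped, and finally consumed by the XlaSparseDenseMatmul* kernels. Purely as orientation, here is a toy sketch of the bucketing step; the real routing depends on the col_shift/col_offset/num_sc_shards attributes and is not specified in this patch, so the modulo rule below is an assumption for illustration only (all names ours):

#include <cstddef>
#include <cstdint>
#include <vector>

struct CooLists {
  std::vector<std::vector<int32_t>> row_ids, col_ids;
  std::vector<std::vector<float>> gains;
};

// Toy illustration: bucket COO entries per SparseCore, assuming ids are
// routed round-robin by column id (a simplification of the real scheme).
CooLists BucketBySparseCore(const std::vector<int32_t>& rows,
                            const std::vector<int32_t>& cols,
                            const std::vector<float>& gains, int num_sc) {
  CooLists out;
  out.row_ids.resize(num_sc);
  out.col_ids.resize(num_sc);
  out.gains.resize(num_sc);
  for (size_t i = 0; i < cols.size(); ++i) {
    int sc = cols[i] % num_sc;  // assumed routing function, for illustration
    out.row_ids[sc].push_back(rows[i]);
    out.col_ids[sc].push_back(cols[i]);
    out.gains[sc].push_back(gains[i]);
  }
  return out;
}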
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 988c749adb8cc6..36fb36a3d451c6 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -160,12 +160,12 @@ OpFoldResult AddNOp::fold(FoldAdaptor adaptor) {
   int non_zero_index = -1;
   auto IsKnownZero = [](Attribute attr) {
     if (!attr) return false;
-    auto splat = attr.dyn_cast<SplatElementsAttr>();
+    auto splat = mlir::dyn_cast<SplatElementsAttr>(attr);
     if (!splat) return false;
     Type element_ty = splat.getType().getElementType();
-    if (element_ty.isa<FloatType>())
-      return splat.getSplatValue<APFloat>().isZero();
-    if (element_ty.isa<IntegerType>())
+    if (mlir::isa<FloatType>(element_ty))
+      return splat.getSplatValue<APFloat>().isZero();
+    if (mlir::isa<IntegerType>(element_ty))
       return splat.getSplatValue<APInt>().getSExtValue() == 0;
     return false;
   };
@@ -180,13 +180,13 @@ OpFoldResult AddNOp::fold(FoldAdaptor adaptor) {
   }

   // Only fold when the result shape is fully static.
-  auto result_ty = getType().dyn_cast<ShapedType>();
+  auto result_ty = mlir::dyn_cast<ShapedType>(getType());
   if (!result_ty || !result_ty.hasStaticShape()) return {};

   if (non_zero_index == -1) {
     return SplatElementsAttr::get(
-        result_ty,
-        operands.begin()->cast<SplatElementsAttr>().getSplatValue<Attribute>());
+        result_ty, mlir::cast<SplatElementsAttr>(*operands.begin())
+                       .getSplatValue<Attribute>());
   }

   // Check the non-zero operand's shape matches the result shape.
@@ -423,7 +423,7 @@ LogicalResult BatchToSpaceOp::verify() {
   int64_t block_size = op.getBlockSize();

   llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamic);
-  auto input_type = op.getInput().getType().cast<ShapedType>();
+  auto input_type = mlir::cast<ShapedType>(op.getInput().getType());
   if (input_type.hasRank()) {
     if (input_type.getRank() != 4)
       return op.emitOpError()
@@ -442,7 +442,7 @@ LogicalResult BatchToSpaceOp::verify() {
                        input_type.getShape().end());
   }

-  auto crops_type = op.getCrops().getType().cast<ShapedType>();
+  auto crops_type = mlir::cast<ShapedType>(op.getCrops().getType());
   if (crops_type.hasRank()) {
     if (crops_type.getRank() != 2)
       return op.emitOpError()
@@ -477,7 +477,7 @@ LogicalResult BatchToSpaceOp::verify() {
     }
   }

-  auto output_type = op.getOutput().getType().cast<ShapedType>();
+  auto output_type = mlir::cast<ShapedType>(op.getOutput().getType());
   if (output_type.hasRank()) {
     if (output_type.getRank() != 4)
       return op.emitOpError()
@@ -567,8 +567,8 @@ void BatchToSpaceOp::getCanonicalizationPatterns(RewritePatternSet& results,

 LogicalResult BatchToSpaceNDOp::verify() {
   BatchToSpaceNDOp op = *this;
-  auto block_shape_ty = op.getBlockShape().getType().cast<ShapedType>();
-  auto crops_ty = op.getCrops().getType().cast<ShapedType>();
+  auto block_shape_ty = mlir::cast<ShapedType>(op.getBlockShape().getType());
+  auto crops_ty = mlir::cast<ShapedType>(op.getCrops().getType());

   if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) {
     const int block_rank = block_shape_ty.getShape().front();
@@ -617,9 +617,9 @@ LogicalResult BiasAddOp::verify() {
     return op.emitOpError("requires bias operand to have rank exactly one");

   RankedTensorType value_ty =
-      op.getValue().getType().dyn_cast<RankedTensorType>();
+      mlir::dyn_cast<RankedTensorType>(op.getValue().getType());
   RankedTensorType bias_ty =
-      op.getBias().getType().dyn_cast<RankedTensorType>();
+      mlir::dyn_cast<RankedTensorType>(op.getBias().getType());
   if (!bias_ty || !value_ty) return success();

   int64_t feature_dim_idx =
@@ -716,7 +716,7 @@ OpFoldResult BroadcastToOp::fold(FoldAdaptor) {

   // Fold broadcast if operand and result types are the same and all dimensions
   // are statically known (no-op broadcast).
-  auto result_ty = getType().dyn_cast<ShapedType>();
+  auto result_ty = mlir::dyn_cast<ShapedType>(getType());
   if (!result_ty || !result_ty.hasStaticShape()) return {};
   if (result_ty == input.getType()) return input;
@@ -818,8 +818,8 @@ LogicalResult BroadcastGradientArgsOp::verify() {

   // Verify that output types are of rank one and matches the computed result
   // shape.
-  auto r0_ty = op.getR0().getType().dyn_cast<RankedTensorType>();
-  auto r1_ty = op.getR1().getType().dyn_cast<RankedTensorType>();
+  auto r0_ty = mlir::dyn_cast<RankedTensorType>(op.getR0().getType());
+  auto r1_ty = mlir::dyn_cast<RankedTensorType>(op.getR1().getType());
   if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getDimSize(0) != r0.size())
     return op.emitOpError() << "requires dimension 0 size of 'r0' to be "
                             << r0.size() << " but got " << r0_ty.getShape()[0];
@@ -852,7 +852,8 @@ LogicalResult BroadcastGradientArgsOp::fold(

   auto build_out_dense_element = [](SmallVectorImpl<int64_t>& shape,
                                     Type input_type) {
-    Type element_type = input_type.cast<RankedTensorType>().getElementType();
+    Type element_type =
+        mlir::cast<RankedTensorType>(input_type).getElementType();
     RankedTensorType type = tensorflow::GetTypeFromTFTensorShape(
         {static_cast<int64_t>(shape.size())}, element_type);
     // Input could only be i32 or i64. For i32, downcast to int32_t array.
@@ -893,7 +894,7 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
   int index = *branch.getValues<int>().begin();
   if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1;

-  auto func = op.getBranches()[index].cast<SymbolRefAttr>();
+  auto func = mlir::cast<SymbolRefAttr>(op.getBranches()[index]);
   auto empty = rewriter.getStringAttr("");
   ReplaceTfOpWithNewOp<PartitionedCallOp>(
       rewriter, op, op.getResultTypes(), op.getOperands().drop_front(), func,
@@ -932,7 +933,7 @@ static LogicalResult VerifyCaseOrIfOpBranchFunctions(
   for (const auto& branch : llvm::enumerate(branches)) {
     auto branch_func = symbol_table.lookupNearestSymbolFrom<func::FuncOp>(
-        op, branch.value().cast<SymbolRefAttr>());
+        op, mlir::cast<SymbolRefAttr>(branch.value()));
     if (!branch_func)
       return op->emitOpError()
              << "expects " << branch_name(branch.index()) << " ("
@@ -1347,12 +1348,10 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
   else
     return failure();
   DenseElementsAttr const_attr;
-  auto scalar_tensor_type =
-      first_arg_op->getOperand(hoist_params->scalar_operand_idx)
-          .getType()
-          .dyn_cast<RankedTensorType>();
+  auto scalar_tensor_type = mlir::dyn_cast<RankedTensorType>(
+      first_arg_op->getOperand(hoist_params->scalar_operand_idx).getType());
   Type scalar_dtype = scalar_tensor_type.getElementType();
-  if (scalar_dtype.isa<FloatType>())
+  if (mlir::isa<FloatType>(scalar_dtype))
     const_attr = DenseElementsAttr::get(scalar_tensor_type,
                                         static_cast<float>(identity_val));
   else
@@ -1450,7 +1449,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(
     } else {
       operand = arg.getDefiningOp()->getOperand(operand_idx);
     }
-    auto ranked = operand.getType().dyn_cast<RankedTensorType>();
+    auto ranked = mlir::dyn_cast<RankedTensorType>(operand.getType());
     return ranked && ranked.getRank() == (axis + 1) &&
            ranked.getShape()[axis] == 1;
   });
@@ -1461,13 +1460,13 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(
     return llvm::all_of(op.getValues(), [&](Value arg) -> bool {
       if (exceptions.count(arg)) return true;
      auto operand = arg.getDefiningOp()->getOperand(operand_idx);
-      auto ranked = operand.getType().dyn_cast<ShapedType>();
+      auto ranked = mlir::dyn_cast<ShapedType>(operand.getType());
       return ranked && ranked.hasRank() && ranked.getRank() == 0;
     });
   };

   // Concat result type must be a ranked tensor.
-  auto ranked = op.getType().dyn_cast<RankedTensorType>();
+  auto ranked = mlir::dyn_cast<RankedTensorType>(op.getType());
   if (!ranked) return std::nullopt;

   // TODO(ezhulenev): Add support for more valid concat patterns.
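HoistCwiseBinaryOutOfConcat is justified by the fact that elementwise ops commute with concatenation. A quick numeric check of that identity for 1-D tensors (plain C++; names ours):

#include <cassert>
#include <cstddef>
#include <vector>

// Axis-0 concatenation of two 1-D tensors.
std::vector<float> Concat(const std::vector<float>& a,
                          const std::vector<float>& b) {
  std::vector<float> out(a);
  out.insert(out.end(), b.begin(), b.end());
  return out;
}

// Elementwise multiply of equally sized 1-D tensors.
std::vector<float> Mul(const std::vector<float>& a,
                       const std::vector<float>& b) {
  std::vector<float> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) out[i] = a[i] * b[i];
  return out;
}

int main() {
  std::vector<float> a{1, 2}, b{3, 4}, c1{10, 10}, c2{20, 20};
  // Hoisting the multiply out of the concat preserves the result, which is
  // what lets the rewrite replace N multiplies with a single one.
  assert(Concat(Mul(a, c1), Mul(b, c2)) == Mul(Concat(a, b), Concat(c1, c2)));
}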
@@ -1527,7 +1526,7 @@ static LogicalResult Verify(OpT op) {
   DenseIntElementsAttr axis_attr;
   if (matchPattern(op.getAxis(), m_Constant(&axis_attr))) {
-    auto input_ty = op.getX().getType().template dyn_cast<RankedTensorType>();
+    auto input_ty = mlir::dyn_cast<RankedTensorType>(op.getX().getType());
     if (input_ty) {
       int64_t rank = input_ty.getRank();
       assert(axis_attr.getNumElements() == 1 &&
@@ -1561,7 +1560,8 @@ LogicalResult ConcatOffsetOp::verify() {
           << "requires sizes of shapes and offsets to be the same, got sizes "
           << op.getShape().size() << " and " << op.getOffset().size();

-  auto ranked_dim = op.getConcatDim().getType().dyn_cast<RankedTensorType>();
+  auto ranked_dim =
+      mlir::dyn_cast<RankedTensorType>(op.getConcatDim().getType());
   if (ranked_dim && ranked_dim.getRank() != 0)
     return op.emitOpError()
            << "requires concat_dim to be a scalar, got tensor of rank "
@@ -1578,7 +1578,7 @@ LogicalResult ConcatOffsetOp::verify() {
       return op.emitOpError()
              << "requires operand and result " << idx
              << " to have compatible shapes";

-    auto ranked_shape = shape.getType().dyn_cast<RankedTensorType>();
+    auto ranked_shape = mlir::dyn_cast<RankedTensorType>(shape.getType());
     if (!ranked_shape) continue;

     if (ranked_shape.getRank() != 1)
@@ -1609,14 +1609,15 @@ LogicalResult ConcatOffsetOp::fold(FoldAdaptor adaptor,
   if (operands.size() < 3) return failure();

   // Check concat_dim is a scalar.
-  auto concat_dim_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  auto concat_dim_attr =
+      mlir::dyn_cast_or_null<DenseIntElementsAttr>(operands[0]);
   if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0)
     return failure();

   llvm::SmallVector<DenseIntElementsAttr, 4> shapes;
   shapes.reserve(operands.size() - 1);
   for (Attribute shape : llvm::drop_begin(operands, 1))
-    if (auto shape_attr = shape.dyn_cast_or_null<DenseIntElementsAttr>())
+    if (auto shape_attr = mlir::dyn_cast_or_null<DenseIntElementsAttr>(shape))
       shapes.push_back(shape_attr);
     else
       return failure();
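ConcatOffsetOp's folder computes, for each input shape, where that input starts in the concatenated result: zero in every dimension except concat_dim, which receives the running sum of the preceding sizes along that dimension. A standalone sketch of that arithmetic (names ours):

#include <cstdint>
#include <vector>

// For inputs whose shapes agree everywhere except `concat_dim`, the i-th
// offset vector is all zeros with the running sum of preceding sizes at
// `concat_dim`.
std::vector<std::vector<int32_t>> ConcatOffsets(
    const std::vector<std::vector<int32_t>>& shapes, int concat_dim) {
  std::vector<std::vector<int32_t>> offsets;
  int32_t running = 0;
  for (const auto& shape : shapes) {
    std::vector<int32_t> off(shape.size(), 0);
    off[concat_dim] = running;
    running += shape[concat_dim];
    offsets.push_back(off);
  }
  return offsets;
}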
@@ -1685,14 +1686,14 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {

 void ConstOp::build(OpBuilder& builder, OperationState& result,
                     Attribute value) {
   ShapedType type;
-  if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
+  if (auto elem_attr = mlir::dyn_cast<ElementsAttr>(value)) {
     return ConstOp::build(builder, result, elem_attr);
-  } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr, StringAttr>()) {
+  } else if (mlir::isa<BoolAttr, FloatAttr, IntegerAttr, StringAttr>(value)) {
     // All TensorFlow types must be tensor types. In the build() method,
     // we want to provide more flexibility by allowing attributes of scalar
     // types. But we need to wrap it up with ElementsAttr to construct
     // valid TensorFlow constants.
-    auto typed_attr = value.cast<TypedAttr>();
+    auto typed_attr = mlir::cast<TypedAttr>(value);
     type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{},
                                                 typed_attr.getType());
     return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
@@ -1704,7 +1705,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result, Type type,
                     Attribute value) {
   // Handle the case where the type and value are already tensors.
-  if (type.isa<TensorType>() && value.isa<ElementsAttr>()) {
+  if (mlir::isa<TensorType>(type) && mlir::isa<ElementsAttr>(value)) {
     result.addTypes(type);
     result.addAttribute("value", value);
     return;
@@ -1722,7 +1723,7 @@ LogicalResult ConstOp::inferReturnTypes(
   ConstOpAdaptor adaptor(operands, attributes, properties, regions);
   auto value = adaptor.getValue();
   if (!value) return emitOptionalError(location, "missing attribute 'value'");
-  if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
+  if (auto elem_attr = mlir::dyn_cast<ElementsAttr>(value)) {
     inferredReturnTypes.assign({elem_attr.getType()});
     return success();
   }
@@ -1743,7 +1744,7 @@ static LogicalResult VerifyConvOpAttributes(
     return emitOptionalError(
         location, "requires strides attribute length to be ", num_dims);
   auto is_not_positive = [](Attribute val) {
-    return val.cast<IntegerAttr>().getValue().getSExtValue() <= 0;
+    return mlir::cast<IntegerAttr>(val).getValue().getSExtValue() <= 0;
   };
   if (llvm::any_of(strides, is_not_positive))
     return emitOptionalError(location, "requires positive strides");
@@ -1793,9 +1794,8 @@ static LogicalResult Verify(OpT op) {

   if (padding == tensorflow::Padding::EXPLICIT) {
     ArrayRef<int64_t> explicit_padding;
-    ArrayAttr explicit_pad =
-        op->getAttr("explicit_paddings")
-            .template dyn_cast_or_null<::mlir::ArrayAttr>();
+    ArrayAttr explicit_pad = mlir::dyn_cast_or_null<::mlir::ArrayAttr>(
+        op->getAttr("explicit_paddings"));
     if (!explicit_pad) {
       explicit_pad = ::mlir::Builder(op->getContext()).getI64ArrayAttr({});
     }
@@ -1812,7 +1812,7 @@ static LogicalResult Verify(OpT op) {
                                num_dims * 2);
     }
     auto is_negative = [](Attribute val) {
-      return val.cast<IntegerAttr>().getValue().getSExtValue() < 0;
+      return mlir::cast<IntegerAttr>(val).getValue().getSExtValue() < 0;
     };
     if (llvm::any_of(explicit_padding, is_negative))
       return emitOptionalError(op.getLoc(),
@@ -1827,7 +1827,7 @@ static LogicalResult Verify(OpT op) {
   }

   int64_t input_channels = ShapedType::kDynamic;
-  if (auto ty = op.getInput().getType().template dyn_cast<RankedTensorType>()) {
+  if (auto ty = mlir::dyn_cast<RankedTensorType>(op.getInput().getType())) {
     absl::string_view data_format(op.getDataFormat().data(),
                                   op.getDataFormat().size());
     tensorflow::TensorFormat format;
@@ -1838,8 +1838,7 @@ static LogicalResult Verify(OpT op) {
   }

   int64_t filter_channels = ShapedType::kDynamic;
-  if (auto ty =
-          op.getFilter().getType().template dyn_cast<RankedTensorType>()) {
+  if (auto ty = mlir::dyn_cast<RankedTensorType>(op.getFilter().getType())) {
     int idx = tensorflow::GetFilterTensorInputChannelsDimIndex(
         num_dims, tensorflow::FORMAT_HWIO);
     filter_channels = ty.getDimSize(idx);
@@ -1891,8 +1890,8 @@ static LogicalResult inferConvReturnTypeComponents(
   const int64_t num_dims = 2 + num_spatial_dims;
   const Value input = op.getInput();
   const Value filter = op.getFilter();
-  const TensorType input_ty = input.getType().template cast<TensorType>();
-  const TensorType filter_ty = filter.getType().template cast<TensorType>();
+  const TensorType input_ty = mlir::cast<TensorType>(input.getType());
+  const TensorType filter_ty = mlir::cast<TensorType>(filter.getType());

   ArrayRef<Attribute> strides = op.getStrides().getValue();
   StringRef data_format = op.getDataFormat();
@@ -1910,7 +1909,7 @@ static LogicalResult inferConvReturnTypeComponents(
   (void)padding_is_valid;

   auto get_int = [](Attribute attr) {
-    return attr.template cast<IntegerAttr>().getInt();
+    return mlir::cast<IntegerAttr>(attr).getInt();
   };

   // Output always have `num_dims` rank. All dimensions are initialized to
@@ -1967,7 +1966,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
   Conv2DOpAdaptor op(operands.getValues(), attributes, properties, regions);
   ArrayRef<int64_t> explicit_padding;
   ArrayAttr explicit_pad =
-      op.getExplicitPaddings().dyn_cast_or_null<::mlir::ArrayAttr>();
+      mlir::dyn_cast_or_null<::mlir::ArrayAttr>(op.getExplicitPaddings());
   if (!explicit_pad) {
     explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
   }
@@ -1984,7 +1983,7 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices& devices) {
     return getDataFormat();

   // Input must be a tensor.
-  auto input_ty = getInput().getType().dyn_cast<RankedTensorType>();
+  auto input_ty = mlir::dyn_cast<RankedTensorType>(getInput().getType());
   if (!input_ty) return getDataFormat();

   // For f16 data type on devices with Tensor Cores support NHWC data format
@@ -1998,7 +1997,7 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices& devices) {
     return getDataFormat();

   // Keep current data format if filter rank is unknown or not equal to 4.
-  auto filter_ty = getFilter().getType().dyn_cast<RankedTensorType>();
+  auto filter_ty = mlir::dyn_cast<RankedTensorType>(getFilter().getType());
   if (!filter_ty || filter_ty.getRank() != 4) return getDataFormat();

   const int64_t d0 = filter_ty.getDimSize(0);
@@ -2006,7 +2005,7 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices& devices) {

   auto all_ones = [](ArrayAttr arr) -> bool {
     return llvm::all_of(arr, [](Attribute attr) -> bool {
-      return attr.cast<IntegerAttr>().getInt() == 1;
+      return mlir::cast<IntegerAttr>(attr).getInt() == 1;
     });
   };
@@ -2068,7 +2067,7 @@ StringRef Conv2DBackpropFilterOp::GetOptimalLayout(
     return getDataFormat();

   // Input must be a tensor.
-  auto input_ty = getInput().getType().dyn_cast<RankedTensorType>();
+  auto input_ty = mlir::dyn_cast<RankedTensorType>(getInput().getType());
   if (!input_ty) return getDataFormat();

   // For f16 data type on devices with Tensor Cores support NHWC data format
@@ -2142,7 +2141,7 @@ StringRef Conv2DBackpropInputOp::GetOptimalLayout(
     return getDataFormat();

   // Filter must be a tensor.
-  auto filter_ty = getFilter().getType().dyn_cast<RankedTensorType>();
+  auto filter_ty = mlir::dyn_cast<RankedTensorType>(getFilter().getType());
   if (!filter_ty) return getDataFormat();

   // For f16 data type on devices with Tensor Cores support NHWC data format
@@ -2177,7 +2176,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(

 LogicalResult DataFormatVecPermuteOp::verify() {
   DataFormatVecPermuteOp op = *this;
-  auto input_ty = op.getX().getType().dyn_cast<RankedTensorType>();
+  auto input_ty = mlir::dyn_cast<RankedTensorType>(op.getX().getType());
   if (!input_ty) return success();

   int rank = input_ty.getRank();
@@ -2285,12 +2284,12 @@ class DivNoNanOrMulNoNanConstantY : public OpRewritePattern {
     if (auto yDefOp = dyn_cast_or_null<TF::ConstOp>(y.getDefiningOp())) {
       Type typeOfElementsInY = getElementTypeOrSelf(y.getType());
       ElementsAttr attr = yDefOp.getValue();
-      bool yHasComplexElements = typeOfElementsInY.isa<ComplexType>();
+      bool yHasComplexElements = mlir::isa<ComplexType>(typeOfElementsInY);

       // If `y` is a splat constant, then the op will definitely get replaced.
       // We check for a splat constant first, in order to optimize the
       // performance of this canonicalization because this check will be O(1).
-      if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+      if (auto splatAttr = mlir::dyn_cast<SplatElementsAttr>(attr)) {
         bool splatAttrIsZero = false;
         if (!yHasComplexElements) {
           if (splatAttr.getSplatValue<APFloat>().isZero())
@@ -2356,7 +2355,8 @@ LogicalResult DynamicStitchOp::verify() {
   if (op.getN() < 1)
     return op.emitOpError("requires attribute N with value >= 1");

-  if (RankedTensorType out_ty = op.getType().dyn_cast<RankedTensorType>()) {
+  if (RankedTensorType out_ty =
+          mlir::dyn_cast<RankedTensorType>(op.getType())) {
     if (out_ty.getRank() == 0) {
       return op.emitOpError("requires non scalar output");
     }
@@ -2383,8 +2383,9 @@ LogicalResult DynamicStitchOp::verify() {
     }

     Value data = std::get<1>(it);
-    RankedTensorType index_ty = index.getType().dyn_cast<RankedTensorType>();
-    RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType index_ty =
+        mlir::dyn_cast<RankedTensorType>(index.getType());
+    RankedTensorType data_ty = mlir::dyn_cast<RankedTensorType>(data.getType());
     if (!index_ty || !data_ty) continue;

     int64_t index_rank = index_ty.getRank();
@@ -2429,7 +2430,7 @@ LogicalResult DynamicStitchOp::verify() {
     expected_shape.append(inferred_item_shape->begin(),
                           inferred_item_shape->end());

-    auto out_ty = op.getType().cast<TensorType>();
+    auto out_ty = mlir::cast<TensorType>(op.getType());
     auto expected_out_ty = tensorflow::GetTypeFromTFTensorShape(
         expected_shape, out_ty.getElementType());

@@ -2471,25 +2472,25 @@ OpFoldResult EmptyOp::fold(FoldAdaptor adaptor) {
   Attribute attr = operands.front();
   if (!attr) return {};

-  auto int_attr = attr.cast<DenseIntElementsAttr>();
+  auto int_attr = mlir::cast<DenseIntElementsAttr>(attr);
   SmallVector<int64_t, 6> out_shape;
   for (const auto val : int_attr.getValues<int32_t>()) {
     out_shape.push_back(val);
   }

-  auto type = getResult().getType().cast<ShapedType>();
+  auto type = mlir::cast<ShapedType>(getResult().getType());
   auto etype = type.getElementType();

   // We can not fold if the result is not static.
   if (!type.hasStaticShape()) return {};

-  if (auto float_type = etype.dyn_cast<FloatType>()) {
+  if (auto float_type = mlir::dyn_cast<FloatType>(etype)) {
     auto out_type = tensorflow::GetTypeFromTFTensorShape(out_shape, float_type);
     return DenseElementsAttr::get(out_type,
                                   {APFloat(float_type.getFloatSemantics())});
   }

-  if (auto int_type = etype.dyn_cast<IntegerType>()) {
+  if (auto int_type = mlir::dyn_cast<IntegerType>(etype)) {
     auto out_type = tensorflow::GetTypeFromTFTensorShape(out_shape, etype);
     APInt val(int_type.getWidth(), 0, int_type.getSignedness());
     return DenseElementsAttr::get(out_type, val);
@@ -2580,7 +2581,7 @@ EnqueueTPUEmbeddingSparseTensorBatchOp::GetResourceInstanceStr() {
 //===----------------------------------------------------------------------===//

 OpFoldResult EnsureShapeOp::fold(FoldAdaptor) {
-  ShapedType type = getInput().getType().dyn_cast<ShapedType>();
+  ShapedType type = mlir::dyn_cast<ShapedType>(getInput().getType());
   if (!type || !type.hasRank()) return {};
   // If shape attribute equals input operand's type's shape, fold it to input.
   std::optional<ArrayRef<int64_t>> shape_constraint = getShape();
@@ -2639,15 +2640,15 @@ static LogicalResult flipComatibleShapeError(Ty op, PatternRewriter& rewriter) {
   // we don't know which one it is. TF shape inference turns unranked outputs
   // into ranked ones if it can statically evaluate the broadcast, see the shape
   // function of tf.Equal.
-  auto ty = op.getType().template dyn_cast<RankedTensorType>();
+  auto ty = mlir::dyn_cast<RankedTensorType>(op.getType());
   if (!ty) {
     return rewriter.notifyMatchFailure(op, "requires a ranked output shape");
   }

   // Unless this is a scalar compare, a scalar output indicates that this will
   // always fail.
-  auto x_ty = op.getX().getType().template dyn_cast<RankedTensorType>();
-  auto y_ty = op.getY().getType().template dyn_cast<RankedTensorType>();
+  auto x_ty = mlir::dyn_cast<RankedTensorType>(op.getX().getType());
+  auto y_ty = mlir::dyn_cast<RankedTensorType>(op.getY().getType());
   if (ty.getRank() == 0 &&
       (!x_ty || x_ty.getRank() != 0 || !y_ty || y_ty.getRank() != 0)) {
     return rewriter.notifyMatchFailure(op, "output rank must match input rank");
@@ -2675,10 +2676,10 @@ void NotEqualOp::getCanonicalizationPatterns(RewritePatternSet& results,
 //===----------------------------------------------------------------------===//

 Type InferExpandDimsOpType(Value input, Value dim) {
-  Type element_ty = input.getType().cast<TensorType>().getElementType();
+  Type element_ty = mlir::cast<TensorType>(input.getType()).getElementType();
   auto unranked_ty = UnrankedTensorType::get(element_ty);

-  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
+  auto input_ty = mlir::dyn_cast<RankedTensorType>(input.getType());
   if (!input_ty) return unranked_ty;

   DenseIntElementsAttr dim_attr;
@@ -2773,7 +2774,7 @@ LogicalResult FakeQuantWithMinMaxVarsPerChannelOp::verify() {
            "requires num_bits to be between 2 and 16, inclusive");
   }

-  auto inputs_type = inputs.getType().dyn_cast<RankedTensorType>();
+  auto inputs_type = mlir::dyn_cast<RankedTensorType>(inputs.getType());
   if (!inputs_type) return success();
   int depth = inputs_type.getDimSize(inputs_type.getRank() - 1);
   if ((min && min.getDimSize(0) != depth) ||
@@ -2800,7 +2801,7 @@ LogicalResult FillOp::verify() {
 }

 static ShapedType InferFillOpType(Value dims, Value value) {
-  Type etype = value.getType().cast<ShapedType>().getElementType();
+  Type etype = mlir::cast<ShapedType>(value.getType()).getElementType();

   DenseIntElementsAttr dims_attr;
   if (matchPattern(dims, m_Constant(&dims_attr))) {
@@ -2813,7 +2814,7 @@ static ShapedType InferFillOpType(Value dims, Value value) {
   }

   if (auto shape_op = dims.getDefiningOp<ShapeOp>()) {
-    if (auto t = shape_op.getInput().getType().dyn_cast<RankedTensorType>()) {
+    if (auto t = mlir::dyn_cast<RankedTensorType>(shape_op.getInput().getType())) {
       return t;
     }
   }
@@ -2830,20 +2831,20 @@ OpFoldResult FillOp::fold(FoldAdaptor adaptor) {
   auto operands = adaptor.getOperands();
   assert(operands.size() == 2 && "fill op has two operand");

-  auto type = getType().cast<ShapedType>();
+  auto type = mlir::cast<ShapedType>(getType());
   // DenseElementsAttr that is used in this folder only supports int and float
   // types.
   // TODO(hinsu): Handle complex types once there is a attribute kind for
   // complex.
   if (!type.getElementType().isIntOrFloat()) return {};

-  auto value = operands[1].dyn_cast_or_null<ElementsAttr>();
+  auto value = mlir::dyn_cast_or_null<ElementsAttr>(operands[1]);
   if (!value) return {};

   if (type.hasStaticShape())
     return DenseElementsAttr::get(type, value.getValues<Attribute>()[0]);

-  auto dims = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  auto dims = mlir::dyn_cast_or_null<DenseIntElementsAttr>(operands[0]);
   if (!dims) return {};

   llvm::SmallVector<int64_t, 4> shape;
@@ -2876,7 +2877,7 @@ StringRef FusedBatchNormGradV3Op::GetOptimalLayout(

   // For f16 data type on devices with Tensor Cores support NHWC data format
   // is up to ~2x faster.
-  auto x_ty = getX().getType().cast<TensorType>();
+  auto x_ty = mlir::cast<TensorType>(getX().getType());
   const bool is_f16 = x_ty.getElementType().isF16();
   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
@@ -2940,7 +2941,7 @@ static StringRef GetOptimalLayout(const RuntimeDevices& devices, Op* op) {

   // For f16 data type on devices with Tensor Cores support NHWC data format
   // is up to ~2x faster.
-  auto x_ty = op->getX().getType().template cast<TensorType>();
+  auto x_ty = mlir::cast<TensorType>(op->getX().getType());
   const bool is_f16 = x_ty.getElementType().isF16();
   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
@@ -3045,7 +3046,7 @@ void GeneratorDatasetRegionOp::getSuccessorRegions(

 LogicalResult GatherV2Op::verify() {
   GatherV2Op op = *this;
   int64_t batch_dims = op.getBatchDims();
-  if (auto ty = op.getIndices().getType().dyn_cast<RankedTensorType>()) {
+  if (auto ty = mlir::dyn_cast<RankedTensorType>(op.getIndices().getType())) {
     int64_t rank = ty.getRank();
     if (batch_dims > rank || batch_dims < -rank)
       return op.emitOpError()
@@ -3060,7 +3061,7 @@ LogicalResult GatherV2Op::verify() {
   DenseIntElementsAttr axis_attr;
   if (matchPattern(op.getAxis(), m_Constant(&axis_attr))) {
     int64_t axis = (*axis_attr.begin()).getSExtValue();
-    if (auto ty = op.getParams().getType().dyn_cast<RankedTensorType>()) {
+    if (auto ty = mlir::dyn_cast<RankedTensorType>(op.getParams().getType())) {
       int64_t rank = ty.getRank();
       if (axis >= rank || axis < -rank)
         return op.emitOpError() << "axis (" << axis << ") must be in range ["
@@ -3283,7 +3284,7 @@ void IfRegionOp::getSuccessorRegions(

 // Verifies that the input is 1D.
 LogicalResult InvertPermutationOp::verify() {
   InvertPermutationOp op = *this;
-  auto x_type = op.getX().getType().cast<TensorType>();
+  auto x_type = mlir::cast<TensorType>(op.getX().getType());
   if (!x_type.hasRank()) return success();
   if (x_type.getShape().size() != 1)
     return op.emitOpError() << "requires input x to be 1-dimensional";
@@ -3310,10 +3311,12 @@ OpFoldResult LeakyReluOp::fold(FoldAdaptor adaptor) {
     return FloatAttr::get(arg.getType(), val);
   };

-  if (auto arg = operands[0].dyn_cast_or_null<FloatAttr>()) {
+  if (auto arg = mlir::dyn_cast_or_null<FloatAttr>(operands[0])) {
     return calculate(arg);
-  } else if (auto arg = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (auto elementAttr = arg.getSplatValue<Attribute>().dyn_cast<FloatAttr>())
+  } else if (auto arg =
+                 mlir::dyn_cast_or_null<SplatElementsAttr>(operands[0])) {
+    if (auto elementAttr =
+            mlir::dyn_cast<FloatAttr>(arg.getSplatValue<Attribute>()))
       return DenseElementsAttr::get(arg.getType(), calculate(elementAttr));
   }
   return {};
@@ -3378,7 +3381,7 @@ OpFoldResult LogicalAndOp::fold(FoldAdaptor adaptor) {
   auto result_type = getType();

   for (const auto& operand : operands) {
-    auto splat_attr = operand.dyn_cast_or_null<SplatElementsAttr>();
+    auto splat_attr = mlir::dyn_cast_or_null<SplatElementsAttr>(operand);
     if (!splat_attr) continue;

     if (splat_attr.getType() != result_type) continue;
@@ -3540,7 +3543,8 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
       dyn_cast_or_null<TF::ConstOp>(getReductionIndices().getDefiningOp());
   if (!reduction_op) return failure();

-  auto reductions_value = reduction_op.getValue().dyn_cast<DenseElementsAttr>();
+  auto reductions_value =
+      mlir::dyn_cast<DenseElementsAttr>(reduction_op.getValue());
   if (!reductions_value) return failure();

   // Prepare new reduction indices according to operand permutation.
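MeanOp::FoldOperandsPermutation has to remap its constant reduction indices when the operand layout is permuted (e.g. NHWC to NCHW). A hedged sketch of that remapping; whether the op applies the permutation in this direction or its inverse is not visible in the hunk above, so treat the mapping below as an assumption (function name ours):

#include <cstdint>
#include <vector>

// Illustrative: if dimension d of the old layout moves to perm[d] in the new
// layout, then a reduction over old dimension d becomes a reduction over
// perm[d].
std::vector<int64_t> PermuteReductionIndices(
    const std::vector<int64_t>& indices, const std::vector<int64_t>& perm) {
  std::vector<int64_t> out;
  out.reserve(indices.size());
  for (int64_t d : indices) out.push_back(perm[d]);
  return out;
}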
@@ -3597,8 +3601,8 @@ void HashTableOp::getCanonicalizationPatterns(RewritePatternSet& results,

 LogicalResult BitcastOp::verify() {
   BitcastOp op = *this;
-  auto input_type = op.getInput().getType().cast<ShapedType>();
-  auto output_type = op.getOutput().getType().cast<ShapedType>();
+  auto input_type = mlir::cast<ShapedType>(op.getInput().getType());
+  auto output_type = mlir::cast<ShapedType>(op.getOutput().getType());
   auto input_element_type = input_type.getElementType();
   auto output_element_type = output_type.getElementType();
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.cc
index d67c1da227d1c6..b3ce501c1c08d1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.cc
@@ -15,6 +15,8 @@ limitations under the License.

 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"

+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
 namespace mlir {
 namespace TF {

@@ -60,7 +62,7 @@ ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation,

 // Shuffle ranked tensor dimensions according to the permutation.
 Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation) {
-  if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
+  if (auto ranked_type = mlir::dyn_cast<RankedTensorType>(type)) {
     ArrayRef<int64_t> shape = ranked_type.getShape();
     assert(permutation.size() == shape.size());
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 3dfe56bc625f5a..c0528f35bd11fe 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include
 #include
 #include
+#include <limits>

 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
@@ -103,6 +104,11 @@ Value LookThroughIdentity(Value result) {
   return result;
 }

+bool IsWithinInt32Range(int64_t value) {
+  return (value >= std::numeric_limits<int32_t>::min() &&
+          value <= std::numeric_limits<int32_t>::max());
+};
+
 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
 }  // namespace

@@ -416,10 +422,27 @@ struct ConvertPackToReshape : public OpRewritePattern {
       return failure();
     }

-    // Create constant shape for reshape.
-    auto type = tensorflow::GetTypeFromTFTensorShape(
+    auto output_int_type = tensorflow::GetTypeFromTFTensorShape(
         output_ty.getRank(), rewriter.getIntegerType(64));
-    auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
+    auto shape_attr =
+        DenseIntElementsAttr::get(output_int_type, output_ty.getShape());
+
+    // Use int32_t instead of int64_t if all elements are within the int32
+    // range, because int64 is not supported by XLA's dynamic reshape.
+    bool elements_all_in_int32_range =
+        std::all_of(output_ty.getShape().begin(), output_ty.getShape().end(),
+                    IsWithinInt32Range);
+
+    if (elements_all_in_int32_range) {
+      std::vector<int32_t> output_shape(output_ty.getRank());
+      std::transform(output_ty.getShape().begin(), output_ty.getShape().end(),
+                     output_shape.begin(),
+                     [](int64_t val) { return static_cast<int32_t>(val); });
+      output_int_type = tensorflow::GetTypeFromTFTensorShape(
+          output_ty.getRank(), rewriter.getIntegerType(32));
+      shape_attr = DenseIntElementsAttr::get(output_int_type, output_shape);
+    }
+
     auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);

     // TODO(b/173622615): Remove after fixed.
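The ConvertPackToReshape change above emits the reshape's shape constant as int32 whenever every dimension fits, since XLA's dynamic reshape lacks int64 support. The guard-and-narrow step in isolation (names ours):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

bool FitsInInt32(int64_t v) {
  return v >= std::numeric_limits<int32_t>::min() &&
         v <= std::numeric_limits<int32_t>::max();
}

// Narrow an int64 shape to int32 when every element fits; otherwise return an
// empty vector so the caller keeps the int64 constant.
std::vector<int32_t> NarrowShapeOrEmpty(const std::vector<int64_t>& shape) {
  if (!std::all_of(shape.begin(), shape.end(), FitsInInt32)) return {};
  std::vector<int32_t> out(shape.size());
  std::transform(shape.begin(), shape.end(), out.begin(),
                 [](int64_t v) { return static_cast<int32_t>(v); });
  return out;
}

int main() {
  assert(NarrowShapeOrEmpty({2, 3, 4}).size() == 3);       // fits: narrowed
  assert(NarrowShapeOrEmpty({int64_t{1} << 40}).empty());  // too big: keep i64
}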
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc
index 24036d17b588e6..ca8f27a1489c06 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 
 namespace mlir {
 namespace TF {
@@ -33,9 +34,9 @@ class IdentityNOp;
 RankedTensorType GetRankedTensorTypeForOperand(Value operand) {
   DenseElementsAttr attr;
   if (matchPattern(operand, m_Constant(&attr))) {
-    return attr.getType().dyn_cast<RankedTensorType>();
+    return mlir::dyn_cast<RankedTensorType>(attr.getType());
   }
-  return operand.getType().dyn_cast<RankedTensorType>();
+  return mlir::dyn_cast<RankedTensorType>(operand.getType());
 }
 
 // Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If
@@ -53,7 +54,7 @@ Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, Value y,
     }
   }
 
-  auto ranked_type = result_type.dyn_cast<RankedTensorType>();
+  auto ranked_type = mlir::dyn_cast<RankedTensorType>(result_type);
   if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type());
 
   return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type());
@@ -65,7 +66,7 @@ Type InferReductionOpType(Value input, Value reduction_indices,
   Type element_ty = getElementTypeOrSelf(input_ty);
 
   // Output type is unranked if input type is not ranked.
-  auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
+  auto ranked_ty = mlir::dyn_cast<RankedTensorType>(input_ty);
   if (!ranked_ty) return UnrankedTensorType::get(element_ty);
   int64_t rank = ranked_ty.getRank();
 
@@ -124,7 +125,7 @@ LogicalResult VerifyTypesCompatibility(Operation::operand_type_range types,
   // the dimension index on the first mismatch and ignore dimension at that
   // index in following types.
   for (Type ty : types) {
-    RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
+    RankedTensorType ranked_ty = mlir::dyn_cast<RankedTensorType>(ty);
     if (!ranked_ty) continue;
 
     int64_t rank = ranked_ty.getRank();
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h
index aaf795afd72917..e77ea7d77deef0 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 
 namespace mlir {
 
@@ -36,7 +37,7 @@ RankedTensorType GetRankedTensorTypeForOperand(Value operand);
 // given `rank`.
 inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) {
   return type && type.getRank() == rank &&
-         type.getElementType().isa<FloatType>();
+         mlir::isa<FloatType>(type.getElementType());
 }
 
 // Returns true if the given `value` has the specified rank or has unranked
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
index e5ecef28a38377..45717471e373a2 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -46,11 +47,11 @@ namespace tf_saved_model { //===----------------------------------------------------------------------===// static bool IsStrArrayAttr(Attribute attr) { - auto array = attr.dyn_cast(); + auto array = mlir::dyn_cast(attr); if (!array) return false; - return llvm::all_of(array, - [](Attribute attr) { return attr.isa(); }); + return llvm::all_of( + array, [](Attribute attr) { return mlir::isa(attr); }); } //===----------------------------------------------------------------------===// @@ -58,10 +59,11 @@ static bool IsStrArrayAttr(Attribute attr) { //===----------------------------------------------------------------------===// LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) { - if (!t1.isa() || !t2.isa()) { + if (!mlir::isa(t1) || !mlir::isa(t2)) { return failure(); } - return verifyCompatibleShape(t1.cast(), t2.cast()); + return verifyCompatibleShape(mlir::cast(t1), + mlir::cast(t2)); } LogicalResult GlobalTensorOp::verify() { @@ -75,7 +77,7 @@ LogicalResult GlobalTensorOp::verify() { } } if (!global_tensor.getIsMutable()) { - if (!global_tensor.getType().cast().hasStaticShape()) { + if (!mlir::cast(global_tensor.getType()).hasStaticShape()) { return global_tensor.emitError() << "'type' attribute for immutable 'tf_saved_model.global_tensor' " "should have a static shape"; @@ -91,7 +93,7 @@ LogicalResult SessionInitializerOp::verify() { for (auto sym_ref : session_initializer.getInitializers()) { auto init_func_op = symbol_table.lookup( - sym_ref.cast().getValue()); + mlir::cast(sym_ref).getValue()); if (!init_func_op) return session_initializer.emitOpError() @@ -143,16 +145,16 @@ TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) } static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) { - auto attr = named_attr.getValue().dyn_cast(); + auto attr = mlir::dyn_cast(named_attr.getValue()); if (!attr) { return op->emitError() << "'" << kTfSavedModelIndexPathAttr << "' attribute should be an ArrayAttr"; } for (auto element : attr) { - if (element.isa()) { + if (mlir::isa(element)) { continue; } - if (auto integer = element.dyn_cast()) { + if (auto integer = mlir::dyn_cast(element)) { if (integer.getValue().getBitWidth() == 64) { continue; } @@ -165,7 +167,7 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) { Type GetBoundInputArgTypeFor(mlir::Operation *op) { if (auto global_tensor = llvm::dyn_cast(op)) { - auto type = global_tensor.getType().cast(); + auto type = mlir::cast(global_tensor.getType()); return RankedTensorType::get( {}, TF::ResourceType::get({type}, type.getContext())); } @@ -196,12 +198,12 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute( Operation *op, unsigned region_index, unsigned arg_index, NamedAttribute named_attr) { if (named_attr.getName() == "tf_saved_model.bound_input") { - if (!named_attr.getValue().isa()) { + if (!mlir::isa(named_attr.getValue())) { return op->emitError() << "'tf_saved_model.bound_input' attribute should " "be a FlatSymbolRefAttr"; } auto symbol_name = - named_attr.getValue().cast().getValue(); + 
     auto module = op->getParentOfType<ModuleOp>();
     mlir::Operation *symbol_op = module.lookupSymbol(symbol_name);
     if (!symbol_op) {
@@ -292,8 +294,8 @@ static LogicalResult VerifySavedModelModule(
             &op, {exported_names_ident, attr}))) {
       return failure();
     }
-    for (auto str : attr.cast<ArrayAttr>()) {
-      auto exported_name = str.cast<StringAttr>().getValue();
+    for (auto str : mlir::cast<ArrayAttr>(attr)) {
+      auto exported_name = mlir::cast<StringAttr>(str).getValue();
       auto p = exported_name_to_op.insert({exported_name, &op});
       if (!p.second) {
         return op.emitError()
@@ -341,7 +343,8 @@ static LogicalResult VerifySavedModelModule(
     auto init_syms = (*session_initializers.begin()).getInitializers();
     return std::any_of(
         init_syms.begin(), init_syms.end(), [&](Attribute sym_ref) {
-          return sym_ref.cast<FlatSymbolRefAttr>().getValue() == func.getName();
+          return mlir::cast<FlatSymbolRefAttr>(sym_ref).getValue() ==
+                 func.getName();
         });
   };
 
@@ -439,7 +442,7 @@ LogicalResult VerifyInitializerTypeAttr(Operation *op,
 
   // Validate the attribute value.
   auto initializer_type_attr_value =
-      named_attr.getValue().dyn_cast_or_null<StringAttr>();
+      mlir::dyn_cast_or_null<StringAttr>(named_attr.getValue());
   if (!initializer_type_attr_value) {
     return op->emitError() << "Attribute tf_saved_model.initializer_type "
                            << "should be a StringAttr.";
@@ -504,7 +507,7 @@ SmallVector<StringRef, 2> GetExportedNames(Operation *op) {
       op->getAttrOfType<ArrayAttr>(kTfSavedModelExportedNamesAttr);
   if (exported_names) {
     for (auto name : exported_names) {
-      ret.push_back(name.cast<StringAttr>().getValue());
+      ret.push_back(mlir::cast<StringAttr>(name).getValue());
     }
   }
   return ret;
@@ -547,7 +550,7 @@ class OptimizeSessionInitializerPattern
     SmallVector<SymbolRefAttr, 2> to_keep;
     for (auto sym_ref : op.getInitializers()) {
       auto init_func_op = symbol_table.lookup<func::FuncOp>(
-          sym_ref.cast<FlatSymbolRefAttr>().getValue());
+          mlir::cast<FlatSymbolRefAttr>(sym_ref).getValue());
 
       // The init function can only be referenced from the SessionInitializerOp.
       // And there is at most one SessionInitializerOp in the module. So if both
@@ -590,7 +593,7 @@ SmallVector<StringRef, 2> GetSessionInitializerExportedName(ModuleOp op) {
 
   SmallVector<StringRef, 2> results;
   for (auto sym_ref : session_initializer_op.getInitializers()) {
     auto init_func_op = symbol_table.lookup<func::FuncOp>(
-        sym_ref.cast<FlatSymbolRefAttr>().getValue());
+        mlir::cast<FlatSymbolRefAttr>(sym_ref).getValue());
     auto exported_names = GetExportedNames(init_func_op);
     assert(exported_names.size() == 1);
     results.push_back(exported_names[0]);
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h
index 62f6192c1f84f0..c6abd7689beddc 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -38,7 +39,7 @@ namespace TF { static inline LogicalResult VerifyRefTypeMatch(mlir::Type type, mlir::Type maybe_ref_type) { if (auto ref_type = - maybe_ref_type.dyn_cast()) + mlir::dyn_cast(maybe_ref_type)) return success(ref_type.RemoveRef().getTypeID() == type.getTypeID()); return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 04052926504174..a5925ac4156baa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -1924,8 +1924,8 @@ func.func @testFoldEnsureShapeOp(%arg0: tensor<10x20xf32>) -> (tensor<10x20xf32> func.func @testConvertPackToReshapeAxis0(%arg0: tensor<2x3xf32>) -> tensor<1x2x3xf32> { %0 = "tf.Pack"(%arg0) {axis = 0 : i64, _xla_outside_compilation = "1", device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>) -> tensor<1x2x3xf32> func.return %0 : tensor<1x2x3xf32> - // CHECK: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[1, 2, 3]> : tensor<3xi64>}> : () -> tensor<3xi64> - // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) {_xla_outside_compilation = "1", device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<1x2x3xf32> + // CHECK: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[1, 2, 3]> : tensor<3xi32>}> : () -> tensor<3xi32> + // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) {_xla_outside_compilation = "1", device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<1x2x3xf32> // CHECK: return %[[RESHAPE]] : tensor<1x2x3xf32> } @@ -1933,8 +1933,8 @@ func.func @testConvertPackToReshapeAxis0(%arg0: tensor<2x3xf32>) -> tensor<1x2x3 func.func @testConvertPackToReshapeAxis1(%arg0: tensor<2x3xf32>) -> tensor<2x1x3xf32> { %0 = "tf.Pack"(%arg0) {axis = 1 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>) -> tensor<2x1x3xf32> func.return %0 : tensor<2x1x3xf32> - // CHECK: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 1, 3]> : tensor<3xi64>}> : () -> tensor<3xi64> - // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<2x1x3xf32> + // CHECK: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 1, 3]> : tensor<3xi32>}> : () -> tensor<3xi32> + // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<2x1x3xf32> // CHECK: return %[[RESHAPE]] : tensor<2x1x3xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir index 90f1cfc2fd5027..bafe05c2eebcc4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir @@ -105,20 +105,18 @@ func.func @cluster_operands(%arg0: tensor) -> tensor { // ----- // Tests cluster 
attributes are copied over to cluster_func. -// Includes device info propagation. // CHECK-LABEL: func @cluster_attrs func.func @cluster_attrs() -> tensor { %0 = "tf_device.cluster"() ({ %1 = "tf.A"() : () -> tensor tf_device.return %1 : tensor - }) {cluster_attr = "cluster_attr", device = "device"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor func.return %0 : tensor } // CHECK: "tf_device.cluster_func" // CHECK-SAME: cluster_attr = "cluster_attr" -// CHECK-SAME: device = "device" // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert_to_legacy_compile_and_replicate_attributes.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert_to_legacy_compile_and_replicate_attributes.mlir index bc900be065c4e8..e27ebb2ea5189c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/convert_to_legacy_compile_and_replicate_attributes.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/convert_to_legacy_compile_and_replicate_attributes.mlir @@ -37,7 +37,7 @@ func.func @convert_to_legacy_attributes_failure(%arg0: tensor<*xf32>, %arg1: ten %outputs_9, %control_10 = tf_executor.island(%control_4) wraps "tf.Identity"(%outputs_7) {_replication_info = "cluster", _tpu_input_identity = true, _xla_compile_device_type = "TPU", device = ""} : (tensor<*xf32>) -> tensor<*xf32> %outputs_11, %control_12 = tf_executor.island wraps "tf.Mul"(%outputs, %outputs_9) {_replication_info = "cluster", _xla_compile_device_type = "TPU", device = ""} : (tensor, tensor<*xf32>) -> tensor<*xf32> %outputs_13, %control_14 = tf_executor.island wraps "tf.AddV2"(%outputs_11, %outputs_0) {_replication_info = "cluster", _xla_compile_device_type = "TPU", device = ""} : (tensor<*xf32>, tensor) -> tensor<*xf32> - // expected-error @+1 {{'tf.Identity' op has '_replication_info' attribute but not '_xla_compile_device_type' attribute which is unsupported}} + // expected-error @+1 {{'tf.Identity' op is expected to have either both or none of '_replication_info' and '_xla_compile_device_type' attributes}} %outputs_15, %control_16 = tf_executor.island wraps "tf.Identity"(%outputs_13) {_replication_info = "cluster", _tpu_output_identity = true, device = "/device:TPU_REPLICATED_CORE:0"} : (tensor<*xf32>) -> tensor<*xf32> %outputs_17, %control_18 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%outputs_15) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> %outputs_19, %control_20 = tf_executor.island(%control_3) wraps "tf.Identity"(%outputs_17) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir index ca1e4c99549d94..9abb90805961c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir @@ -595,3 +595,17 @@ func.func @unsupported_op_gpu_cluster() -> tensor { }) {allow_soft_placement = true, _xla_compile_device_type = "GPU"} : () -> tensor func.return %0 : tensor } + +// CHECK-LABEL: func @xla_host_compute +func.func @xla_host_compute(%arg0: tensor) { + "tf_device.cluster"() ({ + %cst = "tf.Const"() {value = dense<16> : tensor} : () -> tensor + // CHECK: tf.XlaHostCompute + // CHECK-SAME:_xla_original_oc_node_name = "hcb0", _xla_token_input_nodes = ["_xla_token_arg_node"] + "tf.XlaHostCompute"(%cst) <{ancestors = [], cost_estimate_ns = 1000000 : i64, key = "_host_callback", recv_key = "", send_key = "", shapes = [], 
tpu_core = 0 : i64}> {_xla_original_oc_node_name = "hcb0", _xla_token_input_nodes = ["_xla_token_arg_node"]} : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + func.return +} + + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir deleted file mode 100644 index 780406e0c16127..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/simple_tf_dialect_op.mlir +++ /dev/null @@ -1,32 +0,0 @@ -// RUN: tf-mlir-translate -test-only-mlir-to-tf-nodedef %s -o - | FileCheck %s - -func.func @main() { -^bb0: - // CHECK: name: "node_name" - // CHECK-NEXT: op: "Const" - // CHECK-NEXT: attr { - // CHECK: key: "dtype" - // CHECK-NEXT: value { - // CHECK-NEXT: type: DT_INT32 - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: attr { - // CHECK-NEXT: key: "value" - // CHECK-NEXT: value { - // CHECK-NEXT: tensor { - // CHECK-NEXT: dtype: DT_INT32 - // CHECK-NEXT: tensor_shape { - // CHECK-NEXT: dim { - // CHECK-NEXT: size: 2 - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: tensor_content: "\200\000\000\000\200\000\000\000" - // CHECK: experimental_debug_info { - // CHECK-NEXT: original_node_names: "n1" - // CHECK-NEXT: original_func_names: "f1" - // CHECK-NEXT: } - %0 = "tf.Const"() {value = #tf_type : tensor<2xi32>} : () -> (tensor<2xi32>) loc(fused[callsite("n1@f1" at callsite("node_name" at "file_loc"))]) - func.return -} - - diff --git a/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-executor.mlir b/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-executor.mlir deleted file mode 100644 index ddc9ea80d37036..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-executor.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=MlirRoundtripPass | FileCheck %s - -// The test uses the tf_graph_optimization_pass to run the MlirRoundtripPass. -// We convert mlir -> Graph -> mlir -> Graph -> mlir - -func.func @main() { - tf_executor.graph { - %0 = tf_executor.island wraps "tf.NoOp"() {} : () -> () loc("X") - tf_executor.fetch - } - func.return -} - -// Check for the presence of tf.NoOp in the final output. 
-// CHECK: tf.NoOp
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference_with_shape_specialization.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference_with_shape_specialization.mlir
new file mode 100644
index 00000000000000..b016e3ee033012
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference_with_shape_specialization.mlir
@@ -0,0 +1,42 @@
+// RUN: tf-opt %s -tf-shape-inference=input-arg-shapes=1 -verify-diagnostics -split-input-file | FileCheck %s
+// RUN: not tf-opt %s -tf-shape-inference=input-arg-shapes=* 2>&1 | FileCheck --check-prefix=INPUT_ARG_SHAPES_ERROR %s
+// INPUT_ARG_SHAPES_ERROR: Missing input argument shapes
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+  // CHECK-LABEL: func.func @main
+  // CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-NEXT: %[[UQ:.*]] = "tf.UniformQuantize"(%arg0, %cst, %cst_0) <{quantization_axis = -1 : i64, quantization_max_val = 127 : i64, quantization_min_val = -128 : i64}> : (tensor<1xf32>, tensor<f32>, tensor<i32>) -> tensor<1x!tf_type.qint8>
+  // CHECK-NEXT: %[[UDQ:.*]] = "tf.UniformDequantize"(%[[UQ]], %[[CST_0]], %[[CST_1]]) <{quantization_axis = -1 : i64, quantization_max_val = 127 : i64, quantization_min_val = -128 : i64}> : (tensor<1x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<1xf32>
+  // CHECK-NEXT: return %[[UDQ]] : tensor<1xf32>
+  func.func @main(%arg0 : tensor<?xf32>) -> tensor<?xf32> {
+    %scales = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
+    %zps = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
+
+    %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) {
+      quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
+    } : (tensor<?xf32>, tensor<f32>, tensor<i32>) -> tensor<?x!tf_type.qint8>
+    %1 = "tf.UniformDequantize"(%0, %scales, %zps) {
+      quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
+    } : (tensor<?x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<?xf32>
+    func.return %1 : tensor<?xf32>
+  }
+}
+
+// -----
+
+// expected-error@+1 {{Input shapes provided but no `main` function found.}}
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+  func.func @non_main(%arg0 : tensor<?xf32>) -> tensor<?xf32> {
+    %scales = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
+    %zps = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
+
+    %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) {
+      quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
+    } : (tensor<?xf32>, tensor<f32>, tensor<i32>) -> tensor<?x!tf_type.qint8>
+    %1 = "tf.UniformDequantize"(%0, %scales, %zps) {
+      quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
+    } : (tensor<?x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<?xf32>
+    func.return %1 : tensor<?xf32>
+  }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir
index c4ca46d34d7501..4ebc7cb4d063a2 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir
@@ -62,3 +62,30 @@ func.func @multiple_result_user(%arg0: tensor, %arg1: tensor<*x!tf_type.res
 func.func @multiple_result_user_func(%arg0: tensor) -> tensor {
   func.return %arg0 : tensor
 }
+
+// CHECK-LABEL:
@reads_outside_replicate_op +func.func @reads_outside_replicate_op(%arg0: tensor<*x!tf_type.resource>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) { +// CHECK-COUNT-1: tf.ReadVariableOp +// CHECK: tf_device.replicate +// CHECK-NOT: tf.ReadVariableOp + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor<1xf32> + %cst = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64> + %fill = "tf.Fill"(%cst_0, %cst) : (tensor<1xi64>, tensor) -> tensor<1xf32> + tf_device.replicate([%0, %fill] as %arg_r0: tensor<1xf32>) {n = 2 : i32} { + %1 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ + %2 = "tf.Identity"(%arg_r0) : (tensor<1xf32>) -> tensor<1xf32> + tf_device.return %2 : tensor<1xf32> + }) : () -> tensor<1xf32> + %3 = "tf_device.cluster_func"(%1) <{func = @write_chain_func}> {_replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", num_cores_per_replica = 1 : i64} : (tensor<1xf32>) -> tensor<1xf32> + "tf.AssignVariableOp"(%arg0, %3) <{validate_shape = false}> : (tensor<*x!tf_type.resource>>, tensor<1xf32>) -> () + tf_device.return + } + func.return +} + +func.func private @write_chain_func(%arg0: tensor<1xf32>) -> (tensor<1xf32>) { + %cst = "tf.Const"() <{value = dense<[[0, 1]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32> + %0 = "tf.XlaAllReduce"(%arg0, %cst) <{mode = "CrossReplica", reduce_op = "Add"}> : (tensor<1xf32>, tensor<1x2xi32>) -> tensor<1xf32> + return %0 : tensor<1xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index 1d3c1b6f3cf518..4ede37c3d1b4d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -724,8 +724,8 @@ func.func @cluster_ops_keep_replicated_core_attr() { // ----- -func.func @missing_compilation_attribute() { - // expected-error@+1 {{'tf.opA' op has '_replication_info' attribute but not '_xla_compile_device_type' attribute which is unsupported}} +func.func @missing_replication_or_compilation_attribute() { + // expected-error@+1 {{'tf.opA' op is expected to have either both or none of '_replication_info' and '_xla_compile_device_type' attributes}} %0 = "tf.opA"() { _replication_info = "replicate", device = "/device:TPU_REPLICATED_CORE:0", name = "name", is_stateless = true} : () -> tensor "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 1, topology = "topology"} : () -> () func.return @@ -744,124 +744,17 @@ func.func @empty_replication_attribute() { func.func @invalid_device_type() { // expected-error@+1 {{'tf.opA' op has invalid '_xla_compile_device_type' value 'XPU'}} - "tf.opA"() { _xla_compile_device_type = "XPU", _replication_info = "replicate", is_stateless = true} : () -> () - func.return -} - -// ----- - -// Check non-replicated case, including expected attributes at device cluster. 
-// CHECK: "tf_device.cluster"() -// CHECK: "tf.opA"() -// CHECK: "tf.opB"() -// CHECK: tf_device.return -// CHECK: }) {_replication_info = "__no_replication_cluster", _xla_compile_device_type = "TPU", allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i32, step_marker_location = "", topology = "", use_spmd_for_xla_partitioning = false} -func.func @valid_compilation_cluster_no_replication() { - "tf.opA"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> () - func.return -} - -// ----- - -// Check non-replicated case, empty op device to no device in cluster. -// CHECK: "tf_device.cluster"() -// CHECK: "tf.opA"() -// CHECK: "tf.opB"() -// CHECK: tf_device.return -// CHECK-NOT: device = -// CHECK: return -func.func @valid_compilation_cluster_no_replication_empty_op_device() { - "tf.opA"() { _xla_compile_device_type = "TPU", device = ""} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", device = ""} : () -> () - func.return -} - - -// Check non-replicated case, including expected device attr in cluster. -// CHECK: "tf_device.cluster"() -// CHECK: "tf.opA"() -// CHECK: "tf.opB"() -// CHECK: device = "/device:TPU:1" -func.func @valid_compilation_cluster_no_replication_op_device() { - "tf.opA"() { _xla_compile_device_type = "TPU", device = "/device:TPU:1"} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", device = "/device:TPU:1"} : () -> () - func.return -} - -// ----- - -// Check conflicting device names -// CHECK: "tf_device.cluster"() -// CHECK: "tf.opA"() -// CHECK: "tf.opB"() -// CHECK-NOT: device = -func.func @do_nothing_if_short_names_conflict() { - "tf.opA"() { _xla_compile_device_type = "TPU", device = "/replica:1/task:2/device:TPU:1"} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", device = "/replica:3/task:4/device:TPU:1"} : () -> () - func.return -} - -// ----- - -// Check non-replicated case, including expected device attr in cluster. -// CHECK: "tf_device.cluster"() -// CHECK: "tf.opA"() -// CHECK: "tf.opB"() -// CHECK: device = "/task:0/device:TPU:1" -func.func @valid_compilation_cluster_no_replication_op_device() { - "tf.opA"() { _xla_compile_device_type = "TPU", device = "/task:0/device:TPU:1"} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", device = "/device:TPU:1"} : () -> () - func.return -} - -// ----- - -// Check non-replicated case, including expected device attr in cluster. -// CHECK: "tf_device.cluster"() -// CHECK: "tf.opA"() -// CHECK: "tf.opB"() -// CHECK: device = "/task:0/device:TPU:1" -func.func @valid_compilation_cluster_no_replication_op_device() { - "tf.opA"() { _xla_compile_device_type = "TPU", device = "/device:TPU:1"} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", device = "/task:0/device:TPU:1"} : () -> () - func.return -} - -// ----- - -// Check non-replicated case, empty op device to no device in cluster. -// CHECK: "tf_device.cluster"() -// CHECK: "tf.opA"() -// CHECK: "tf.opB"() -// CHECK: tf_device.return -// CHECK-NOT: device = -// CHECK: return -func.func @valid_compilation_cluster_no_replication_op_device() { - "tf.opA"() { _xla_compile_device_type = "TPU", device = "/device:TPU:0"} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", device = "/task:0/device:TPU:1"} : () -> () - func.return -} - -// ----- - -// Check non-replicated case, empty op device to no device in cluster. 
-// CHECK: "tf_device.cluster"() -// CHECK: "tf.opA"() -// CHECK: "tf.opB"() -// CHECK: tf_device.return -// CHECK-NOT: device = -// CHECK: return -func.func @valid_compilation_cluster_no_replication_op_device() { - "tf.opA"() { _xla_compile_device_type = "TPU", device = "/device:CPU:0"} : () -> () + %0 = "tf.opA"() {_xla_compile_device_type = "XPU", _replication_info = "replicate", device = "/device:TPU:0", name = "name", is_stateless = true} : () -> tensor + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 1, topology = "topology"} : () -> () func.return } // ----- // expected-error@+1 {{found different '_xla_compile_device_type' attribute values (GPU,TPU) in same block which is not supported}} func.func @invalid_compilation_cluster_mixed_device_types() { - "tf.opA"() { _xla_compile_device_type = "GPU", is_stateless = true} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> () + "tf.opA"() { _xla_compile_device_type = "GPU", _replication_info = "replicate", is_stateless = true} : () -> () + "tf.opB"() { _xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : () -> () + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 1, topology = "topology"} : () -> () func.return } @@ -871,37 +764,31 @@ func.func @invalid_compilation_cluster_mixed_device_types() { func.func @invalid_compilation_replication_cluster_mixed_device_types() { "tf.opA"() { _xla_compile_device_type = "CPU", _replication_info = "cluster", is_stateless = true} : () -> () "tf.opB"() { _xla_compile_device_type = "GPU", _replication_info = "cluster", is_stateless = true} : () -> () - func.return -} - -// ----- - -// expected-error@+1 {{found mixed replicated and non-replicated compiled ops in same block which is not supported}} -func.func @mixed_replicated_non_replicated_ops() { - "tf.opA"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", _replication_info = "cluster", is_stateless = true} : () -> () + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 1, topology = "topology"} : () -> () func.return } // ----- func.func @cyclic_control_dependency_no_replication() { - "tf.opA"() {_xla_compile_device_type = "TPU"} : () -> () + "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate"} : () -> () // expected-warning-re@+1 {{Op has cyclic dependency with a compilation cluster{{.*}}}} "tf.opB"() : () -> () - "tf.opC"() {_xla_compile_device_type = "TPU"} : () -> () + "tf.opC"() {_xla_compile_device_type = "TPU", _replication_info = "replicate"} : () -> () + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 1, topology = "topology"} : () -> () func.return } // ----- func.func @cyclic_data_dependency_no_replication() { - %0 = "tf.opA"() {_xla_compile_device_type = "TPU", is_stateless = true} : () -> (tensor) + %0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : () -> (tensor) // expected-warning-re@+2 {{Op has cyclic dependency with a compilation cluster{{.*}}}} // expected-error@+1 {{operand #0 does not dominate this use}} %1 = "tf.opB"(%0) {is_stateless = true} : (tensor) -> 
(tensor) // expected-note@+1 {{operand defined here (op in the same block)}} - "tf.opC"(%1) {_xla_compile_device_type = "TPU", is_stateless = true} : (tensor) -> () + "tf.opC"(%1) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor) -> () + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 1, topology = "topology"} : () -> () func.return } @@ -920,12 +807,12 @@ func.func @cyclic_control_dependency_replication() { // ----- func.func @cyclic_data_dependency_replication() { - %0 = "tf.opA"() {_xla_compile_device_type = "TPU", is_stateless = true} : () -> (tensor) + %0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", is_stateless = true} : () -> (tensor) // expected-warning-re@+2 {{Op has cyclic dependency with a compilation cluster{{.*}}}} // expected-error@+1 {{operand #0 does not dominate this use}} %1 = "tf.opB"(%0) {is_stateless = true} : (tensor) -> (tensor) // expected-note@+1 {{operand defined here (op in the same block)}} - "tf.opC"(%1) {_xla_compile_device_type = "TPU", is_stateless = true} : (tensor) -> () + "tf.opC"(%1) {_xla_compile_device_type = "TPU", _replication_info = "cluster", is_stateless = true} : (tensor) -> () "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () func.return } @@ -935,6 +822,7 @@ func.func @cyclic_data_dependency_replication() { // expected-warning@+1 {{TPUReplicateMetadata for associated '_replication_info' attribute 'cluster' is missing}} func.func @missing_metadata() { "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : () -> () + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () func.return } @@ -1035,11 +923,12 @@ func.func @gather_nd(%arg0: tensor<*x!tf_type.resource>>, Tindices = i32 } : (tensor<*x!tf_type.resource>>, tensor) -> tensor<1x80xf32> %2 = "tf.Add"(%1, %1) { - _xla_compile_device_type = "TPU", + _xla_compile_device_type = "TPU", _replication_info = "cluster", device = "/task:0/device:TPU:0", dtype = f32 } : (tensor<1x80xf32>, tensor<1x80xf32>) -> tensor<1x80xf32> %3 = "tf.ResourceGatherNd"(%arg0, %0) { Tindices = i32 } : (tensor<*x!tf_type.resource>>, tensor) -> tensor<1x80xf32> + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", device = "/device:TPU:0", num_replicas = 1, topology = "topology"} : () -> () func.return } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index db28242944434e..a148b78bf42332 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -605,41 +605,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests user given device in cluster_func is propagated correctly. 
- -module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @no_replication_device - func.func @no_replication_device() { - "tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "__no_replication_cluster", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device = "/job:worker/replica:0/task:0/device:TPU:1", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () - // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:TPU:1"}> - // CHECK: tf.TPUExecute - // CHECK-NEXT: tf_device.return - func.return - } - func.func @empty_func() { - func.return - } -} - -// ----- - -// Tests CPU given device in cluster_func is not propagated. - -module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @no_replication_device - func.func @no_replication_device() { - "tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "__no_replication_cluster", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device = "/job:worker/replica:0/task:0/device:CPU:0", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () - // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:TPU:0"}> - // CHECK: tf.TPUExecute - // CHECK-NEXT: tf_device.return - func.return - } - func.func @empty_func() { - func.return - } -} - -// ----- // Tests metadata is populated correctly for use_spmd_for_xla_partitioning == // true. 
@@ -2579,8 +2544,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - func.func @missing_compilation_attribute() { - // expected-error@+1 {{'tf_device.cluster_func' op has '_replication_info' attribute but not '_xla_compile_device_type' attribute which is unsupported}} + func.func @missing_compilation_and_replication_attributes() { + // expected-error@+1 {{'tf_device.cluster_func' op is expected to have either both or none of '_replication_info' and '_xla_compile_device_type' attributes}} "tf_device.cluster_func"() {_replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () func.return } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir index 4af8cdf06f727f..295079b24fe799 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @num_replicas_replicated func.func @num_replicas_replicated(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { %0:2 = tf_executor.graph { - %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor %ro:2, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) @@ -16,7 +16,7 @@ func.func @num_replicas_replicated(%arg0: tensor, %arg1: tensor, %arg2 func.func @num_replicas_replicated_input(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { %0:2 = tf_executor.graph { - %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () // expected-error @+1 {{'tf.TPUReplicatedInput' op TF2XLA TPU bridge input check: number of inputs inconsistent. num_replicas=2 no. 
of inputs=3}} %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor, tensor) -> tensor %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor @@ -30,7 +30,7 @@ func.func @num_replicas_replicated_input(%arg0: tensor, %arg1: tensor, func.func @num_replicas_replicated_input_packed(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { %0:2 = tf_executor.graph { - %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () // expected-error @+1 {{'tf.TPUReplicatedInput' op TF2XLA TPU bridge input check: packed with number of inputs not 1. num_replicas=2 no. of inputs=2}} %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = true} : (tensor, tensor) -> tensor %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor @@ -44,7 +44,7 @@ func.func @num_replicas_replicated_input_packed(%arg0: tensor, %arg1: tenso func.func @num_replicas_replicated_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { %0:2 = tf_executor.graph { - %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor // expected-error @+1 {{'tf.TPUReplicatedOutput' op TF2XLA TPU bridge input check: number of outputs inconsistent. num_replicas=2 no. of outputs=3}} @@ -58,7 +58,7 @@ func.func @num_replicas_replicated_output(%arg0: tensor, %arg1: tensor func.func @num_core_per_replica_partitioned_input(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { %0:2 = tf_executor.graph { - %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () // expected-error @+1 {{'tf.TPUPartitionedInput' op TF2XLA TPU bridge input check: number of inputs inconsistent. num_cores_per_replica=2 no. 
of inputs=3}} %pi, %c0 = tf_executor.island wraps "tf.TPUPartitionedInput"(%arg0, %arg1, %arg1) {index = 1 : i64} : (tensor, tensor, tensor) -> tensor %out, %c1 = tf_executor.island wraps "tf.opA"(%pi) {_tpu_replicate = "cluster"} : (tensor) -> tensor @@ -72,7 +72,7 @@ func.func @num_core_per_replica_partitioned_input(%arg0: tensor, %arg1: ten func.func @num_core_per_replica_partitioned_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { %0:2 = tf_executor.graph { - %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () %pi, %c0 = tf_executor.island wraps "tf.TPUPartitionedInput"(%arg0, %arg1) {index = 1 : i64} : (tensor, tensor) -> tensor %out, %c1 = tf_executor.island wraps "tf.opA"(%pi) {_tpu_replicate = "cluster"} : (tensor) -> tensor // expected-error @+1 {{'tf.TPUPartitionedOutput' op TF2XLA TPU bridge input check: number of outputs inconsistent. num_cores_per_replica=2 no. of outputs=3}} @@ -86,7 +86,7 @@ func.func @num_core_per_replica_partitioned_output(%arg0: tensor, %arg1: te func.func @validate_tpu_replicate_no_attr(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { %0:2 = tf_executor.graph { - %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate="cluster"}: (tensor) -> tensor // expected-warning @+1 {{TF2XLA TPU bridge input check: cluster op = tf.opA with cluster = cluster has successor as non cluster op tf.opB}} @@ -102,7 +102,7 @@ func.func @validate_tpu_replicate_no_attr(%arg0: tensor, %arg1: tensor func.func @validate_tpu_replicate_wrong_attr(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { %0:2 = tf_executor.graph { - %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster_wrong"}: (tensor) -> tensor // expected-error @+1 {{'tf.opB' op TF2XLA TPU bridge input check: mismatch clusters tpu_replicate attr. 
Parent op tf.opA with cluster = cluster_wrong has successor cluster op tf.opB with cluster = cluster}} @@ -117,7 +117,7 @@ func.func @validate_tpu_replicate_wrong_attr(%arg0: tensor, %arg1: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { %0:2 = tf_executor.graph { - %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster", device = "TPU"} : (tensor) -> tensor %ro:2, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) @@ -130,7 +130,7 @@ func.func @valid_xla_nonxla(%arg0: tensor, %arg1: tensor, %arg2: tenso func.func @valid_xla_nonxla_warning(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor<*x!tf_type.string>, tensor<*x!tf_type.string>) { %0:2 = tf_executor.graph { - %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor<*x!tf_type.string> // expected-warning @+1 {{TF/XLA TPU bridge input check: found invalid op. 
tf.Identity can't be both xla and non-xla}} %out, %c1 = tf_executor.island(%c0) wraps "tf.Identity"(%ri) {_tpu_replicate = "cluster", device = ""} : (tensor<*x!tf_type.string>) -> tensor<*x!tf_type.string> @@ -151,7 +151,7 @@ func.func @valid_xla_nonxla_warning(%arg0: tensor, %arg1: tensor, %arg func.func @valid_MAXIMAL_sharding_device(%arg0: tensor) -> tensor { %0 = tf_executor.graph { - %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () + %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () %0, %c = tf_executor.island wraps "tf.Identity"(%arg0) {_tpu_replicate = "cluster", _XlaSharding = "\08\01\1A\01\01\22\01\00"} : (tensor) -> tensor tf_executor.fetch %0 : tensor } @@ -168,7 +168,7 @@ func.func @valid_MAXIMAL_sharding_device(%arg0: tensor) -> tensor { func.func @invalid_MAXIMAL_sharding_device(%arg0: tensor) -> tensor { %0 = tf_executor.graph { - %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () + %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () // expected-error @+1 {{'tf.Identity' op TF2XLA TPU bridge input check: invalid sharding device 2 for num_cores_per_replica = 2}} %0, %c = tf_executor.island wraps "tf.Identity"(%arg0) {_tpu_replicate = "cluster", _XlaSharding = "\08\01\1A\01\01\22\01\02"} : (tensor) -> tensor tf_executor.fetch %0 : tensor @@ -194,7 +194,7 @@ func.func @invalid_MAXIMAL_sharding_device(%arg0: tensor) -> tensor { func.func @invalid_TUPLE_sharding_arity(%arg0: tensor) -> tensor { %0 = tf_executor.graph { - %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () + %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () // expected-error @+1 {{'tf.Identity' op TF2XLA TPU bridge input check: invalid no. 
of tuple shardings 2 for arity = 1}} %0, %c = tf_executor.island wraps "tf.Identity"(%arg0) {_tpu_replicate = "cluster", _XlaSharding = "\08\02\2a\08\08\01\1a\01\01\22\01\00\2a\08\08\01\1a\01\01\22\01\01"} : (tensor) -> tensor tf_executor.fetch %0 : tensor @@ -220,11 +220,36 @@ func.func @invalid_TUPLE_sharding_arity(%arg0: tensor) -> tensor { func.func @outfeed_enqueue_tuple_sharding_exception(%arg0: tensor, %arg1: tensor) -> tensor { %0 = tf_executor.graph { - %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () + %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () %0, %c0 = tf_executor.island wraps "tf.AddV2"(%arg0, %arg1) {_tpu_replicate = "cluster"} : (tensor, tensor) -> tensor %c1 = tf_executor.island wraps "tf.OutfeedEnqueueTuple"(%arg0, %arg1) {_tpu_replicate = "cluster", _XlaSharding = "\08\02\2a\08\08\01\1a\01\01\22\01\00\2a\08\08\01\1a\01\01\22\01\01"} : (tensor, tensor) -> () tf_executor.fetch %0 : tensor } return %0 : tensor } -// ----- \ No newline at end of file + + +// ----- + +func.func @single_core_tpu(%arg0: tensor) -> () { + tf_executor.graph { + // expected-error @+1 {{found a single-core TPU graph}} + tf_executor.island wraps "tf.Identity"(%arg0) {_xla_compile_device_type = "TPU"} : (tensor) -> tensor + tf_executor.fetch + } + return +} + +// ----- + +// CHECK-LABEL: func @num_replicas_1 +func.func @num_replicas_1(%arg0: tensor) -> (tensor) { + %0 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 1, num_cores_per_replica = 1, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %ro, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> tensor + tf_executor.fetch %ro : tensor + } + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_validate_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_validate_inputs.mlir new file mode 100644 index 00000000000000..1beb284ce9509a --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_validate_inputs.mlir @@ -0,0 +1,21 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-xla-validate-inputs + +// expected-error @+1 {{expects no nested calls of entry functions as they prevent graph traversal in some passes from working correctly}} +func.func @nested_entry_functions() attributes {tf.entry_function = {}} { + tf_executor.graph { + %control = tf_executor.island wraps "tf.StatefulPartitionedCall"() {config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @func} : () -> () + tf_executor.fetch + } + func.return +} + +func.func @func() attributes {tf.entry_function = {}} { + func.return +} + +// ----- + +// expected-error @+1 {{does not support top-level compilation marker}} +func.func @top_level_compilation_marker() attributes {_xla_compile_device_type = "CPU", tf.entry_function = {}} { + func.return +} diff --git 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_validate_iputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_validate_iputs.mlir
deleted file mode 100644
index 81f3398321f569..00000000000000
--- a/tensorflow/compiler/mlir/tensorflow/tests/xla_validate_iputs.mlir
+++ /dev/null
@@ -1,20 +0,0 @@
-// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-xla-validate-inputs
-
-// expected-error @+1 {{TF2XLA MLIR CPU/GPU phase 1 bridge expects no nested calls of entry functions as they prevent graph traversal in some passes from working correctly}}
-func.func @nested_entry_functions(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} {
-  %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @func} : (tensor) -> (tensor)
-  func.return %0 : tensor
-}
-
-func.func @func(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} {
-  func.return %arg0 : tensor
-}
-
-// -----
-
-// expected-error @+1 {{TF2XLA MLIR CPU/GPU MLIR phase 1 bridge expects single region and single block in an entry function}}
-func.func @multi_blocks_entry_function(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} {
-  cf.br ^bb1
-^bb1:
-  func.return %arg0 : tensor
-}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
index 2e090224a5c86c..4daaf633212451 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
@@ -147,6 +147,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
         "//tensorflow/core:framework",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
     ],
 )
 
@@ -298,6 +299,7 @@ cc_library(
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
     ],
 )
 
@@ -531,7 +533,6 @@ cc_library(
         "tpu_reorder_replicate_and_partitioned_inputs.cc",
         "tpu_resource_partitioning.cc",
         "tpu_resource_read_for_write.cc",
-        "tpu_sharding_identification_pass.cc",
         "tpu_space_to_depth_pass.cc",
         "tpu_update_embedding_enqueue_op_inputs.cc",
         "tpu_validate_inputs.cc",
@@ -777,6 +778,7 @@ cc_library(
         ":tf_pass_inc_gen",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils",
+        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
         "//tensorflow/compiler/mlir/tensorflow:shape_inference_utils",
         "//tensorflow/compiler/mlir/tensorflow:translate_utils",
@@ -786,6 +788,9 @@ cc_library(
         "//tensorflow/core/ir/types:Dialect",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FuncExtensions",
@@ -796,6 +801,7 @@ cc_library(
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
+        "@local_tsl//tsl/platform:errors",
         "@local_xla//xla:shape_util",
         "@local_xla//xla:window_util",
         "@local_xla//xla:xla_data_proto_cc",
@@ -1029,6 +1035,7 @@ cc_library(
         "//tensorflow/core:framework",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
     ],
 )
 
@@ -1040,6 +1047,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
         "@local_xla//xla:shape_util",
         "@local_xla//xla/mlir_hlo",
         "@local_xla//xla/stream_executor/tpu:c_api_conversions",
@@ -1062,6 +1070,7 @@ cc_library(
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
         "@local_tsl//tsl/platform:path",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc
index 996686eb525d03..52765fb5657eba 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@@ -70,14 +71,14 @@ void AnnotateParameterReplicationPass::runOnOperation() {
     if (mirrored_variable_indices_attr) {
       for (const auto& mirrored_index : mirrored_variable_indices_attr) {
         mirrored_replicate_args.insert(
-            mirrored_index.cast().getInt());
+            mlir::cast(mirrored_index).getInt());
       }
     }
     auto func = llvm::cast(m.lookupSymbol(cluster_func.getFunc()));
     for (auto entry : llvm::enumerate(cluster_func.getOperands())) {
       auto operand = SkipIdentityAndReadVariable(entry.value());
-      auto block_arg = operand.dyn_cast();
+      auto block_arg = mlir::dyn_cast(operand);
       if (block_arg && block_arg.getOwner() == &replicate.GetBody()) {
         // Only mirrored args of ReplicateOp can be annotated.
         if (mirrored_replicate_args.count(block_arg.getArgNumber()) == 0) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc
index 14ab17be0fdee5..c6e21cb1e03054 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc
@@ -53,8 +53,8 @@ class ConvertTFBatchMatMulToEinsumOp
     Value input_rhs = op.getY();
 
     // LHS and RHS must be a ranked tensor type
-    auto lhs_type = input_lhs.getType().dyn_cast();
-    auto rhs_type = input_rhs.getType().dyn_cast();
+    auto lhs_type = mlir::dyn_cast(input_lhs.getType());
+    auto rhs_type = mlir::dyn_cast(input_rhs.getType());
 
     if (!lhs_type || !rhs_type) return failure();
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc
index 4b409ffe1f614f..3ce5fb5bcb8379 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #define DEBUG_TYPE "cluster-ops-by-policy" @@ -44,7 +45,7 @@ ValueConstraint Merge(ValueConstraint a, ValueConstraint b) { LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint) { // Resolve constraints inferred from the tensor type. - if (auto tensor = value.getType().dyn_cast()) { + if (auto tensor = mlir::dyn_cast(value.getType())) { if (constraint == ValueConstraint::kRank && tensor.hasRank()) return success(); if (constraint == ValueConstraint::kShape && tensor.hasStaticShape()) @@ -710,7 +711,7 @@ void EmitValueConstraintsRemarks(const ValuesConstraintSet &constraints) { void EmitInputsConstraintsRemarks(func::FuncOp func, const ValuesConstraintSet &constraints) { constraints.Walk([&](Value value, ValueConstraint constraint) { - if (auto arg = value.dyn_cast()) + if (auto arg = mlir::dyn_cast(value)) if (arg.getOwner() == &func.getBody().front()) func.emitRemark(llvm::formatv("input #{0} constrained to: {1}", arg.getArgNumber(), constraint)); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index cbe4ae6b2e41b1..355aded4f2d97a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -134,10 +134,6 @@ void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table, auto cluster_func_op = builder->create( cluster_op.getLoc(), outlined_func.getFunctionType().getResults(), live_ins.getArrayRef(), cluster_op->getAttrs()); - auto device_attr = cluster_op->getAttrOfType(TF::kDeviceAttr); - if (device_attr && !device_attr.getValue().empty()) { - cluster_func_op->setAttr(TF::kDeviceAttr, device_attr); - } cluster_op.replaceAllUsesWith(cluster_func_op); cluster_op.erase(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc index 5d9f5f9718446f..3d3e1305993a30 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc @@ -33,6 +33,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/util/device_name_utils.h" @@ -124,7 +125,7 @@ std::optional> GetFunctionMetadatas( // If the value is defined as an argument of the func_op, adds it to // the argument list of the function that uses this op. 
- if (BlockArgument block_arg = value.dyn_cast()) { + if (BlockArgument block_arg = mlir::dyn_cast(value)) { if (StringAttr attr = func_op.getArgAttrOfType( block_arg.getArgNumber(), kTFDeviceAttr)) { value_device = attr.getValue().str(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index 59faa220521f0b..5a83e75e9eedf4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -62,7 +62,7 @@ Value GetR1Const(ArrayRef r1, OpBuilder builder, Location loc, Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, Location loc) { - auto buffer_type = buffer.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); if (buffer_type.getShape().size() == 1) return index; // Create a concat of index and trailing zeros. llvm::SmallVector zeros(buffer_type.getShape().size() - 1, 0); @@ -77,7 +77,7 @@ Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc, bool keep_slice_shape) { - auto buffer_type = buffer.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); // Create a slice then reshape to remove the leading trivial dimension of // size 1. llvm::SmallVector slice_size = @@ -102,7 +102,7 @@ Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc, Value SetElement(Value index, Value buffer, Value element, OpBuilder builder, Location loc) { - auto buffer_type = buffer.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); // Reshape the element to add a leading dimension of size 1 if th element does // not have that dimension, then perform a dynamic update slice. 
auto slice_shape = llvm::to_vector<8>(buffer_type.getShape()); @@ -208,7 +208,7 @@ std::optional GetElementTypeFromAccess( if (type_from_alias.has_value()) return type_from_alias; } else if (auto type = infer_from_op(use.getOwner())) { if (!type) continue; - auto elem_type = type->dyn_cast(); + auto elem_type = mlir::dyn_cast(*type); if (elem_type && elem_type.hasStaticShape()) return elem_type; } } @@ -220,8 +220,8 @@ Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) { return builder .create( loc, - ArrayRef{getElementTypeOrSelf(local_var.getType()) - .cast() + ArrayRef{mlir::cast( + getElementTypeOrSelf(local_var.getType())) .getSubtypes()[0]}, ArrayRef{local_var}) .getValue(); @@ -246,7 +246,7 @@ Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc) { namespace { int64_t GetFirstIfIndicesAreContiguous(Value indices) { - auto type = indices.getType().dyn_cast(); + auto type = mlir::dyn_cast(indices.getType()); if (!type) return -1; auto indices_op = indices.getDefiningOp(); if (!indices_op) return -1; @@ -270,9 +270,10 @@ int64_t GetFirstIfIndicesAreContiguous(Value indices) { Value GatherElements(Value indices, Value buffer, OpBuilder builder, Location loc) { - auto buffer_type = buffer.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); auto result_shape = llvm::to_vector<8>(buffer_type.getShape()); - result_shape[0] = indices.getType().cast().getDimSize(0); + result_shape[0] = + mlir::cast(indices.getType()).getDimSize(0); int64_t maybe_contiguous_start = GetFirstIfIndicesAreContiguous(indices); if (maybe_contiguous_start >= 0) { llvm::SmallVector slice_starts(result_shape.size(), 0); @@ -293,8 +294,8 @@ Value GatherElements(Value indices, Value buffer, OpBuilder builder, Value ScatterAccumulateElements(Value indices, Value updates, Value buffer, OpBuilder builder, Location loc) { - auto buffer_type = buffer.getType().cast(); - auto updates_type = updates.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); + auto updates_type = mlir::cast(updates.getType()); int64_t maybe_contiguous_start = GetFirstIfIndicesAreContiguous(indices); if (maybe_contiguous_start == 0 && buffer_type == updates_type) { return AccumulateBuffers(buffer, updates, builder, loc); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 880dfa837e881c..ca5eb4bc737b99 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -47,7 +48,7 @@ static bool IsFoldedByDefaultPolicy(Operation* inst) { auto get_size = [&](TypeRange types) { int64_t size = 0; for (auto t : types) { - auto tensor_type = t.cast(); + auto tensor_type = mlir::cast(t); // Ignore types with undefined bit widths. if (!tensor_type.getElementType().isIntOrFloat()) continue; if (!tensor_type.hasStaticShape()) { @@ -93,7 +94,7 @@ LogicalResult ConstantFoldFallbackHook( // propagation. 
bool has_empty_numerical_results = llvm::all_of(inst->getResultTypes(), [](Type ty) { - ShapedType shaped_ty = ty.cast(); + ShapedType shaped_ty = mlir::cast(ty); Type element_ty = shaped_ty.getElementType(); return shaped_ty.hasStaticShape() && shaped_ty.getNumElements() == 0 && element_ty.isIntOrFloat(); @@ -103,7 +104,7 @@ LogicalResult ConstantFoldFallbackHook( // addressed. inst->isRegistered()) { for (Type ty : inst->getResultTypes()) { - auto shaped_ty = ty.cast(); + auto shaped_ty = mlir::cast(ty); results.push_back( DenseElementsAttr::get(shaped_ty, llvm::ArrayRef())); } @@ -112,14 +113,14 @@ LogicalResult ConstantFoldFallbackHook( // Returns directly if any of the operands is not an elements attributes. if (std::any_of(operands.begin(), operands.end(), [](Attribute attr) { - return !attr || !attr.isa(); + return !attr || !mlir::isa(attr); })) return failure(); SmallVector inputs; inputs.reserve(operands.size()); for (auto input : operands) { - inputs.push_back(input.cast()); + inputs.push_back(mlir::cast(input)); } SmallVector constants; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc index 6d28fa03a988a6..84c96590910243 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" @@ -65,8 +66,8 @@ bool CanBeFolded(Operation* inst) { // This creates opaque variant constants which lose information and would // require "raising" later. 
for (const Type type : inst->getResultTypes()) { - if (const TensorType tensor_type = type.dyn_cast()) { - if (tensor_type.getElementType().isa()) { + if (const TensorType tensor_type = mlir::dyn_cast(type)) { + if (mlir::isa(tensor_type.getElementType())) { return false; } } @@ -134,7 +135,7 @@ LogicalResult EvaluateOperation(Operation* inst, node_def->get()->op(), node_def->get()->name(), host_cpu, operands.size(), [&](tensorflow::AttrValueMap* attr_value_map) { *attr_value_map = node_def->get()->attr(); - return tensorflow::OkStatus(); + return absl::OkStatus(); }, fallback_state.device_manager(), fallback_state.process_function_library_runtime()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index 4de43317677f63..6262cad26ca6e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -114,7 +114,7 @@ SmallVector GetWhileCallers(func::FuncOp func, } bool IsResourceType(Type type) { - return getElementTypeOrSelf(type).isa(); + return mlir::isa(getElementTypeOrSelf(type)); } bool OnlyOperatesOnCompositeDevices( @@ -124,11 +124,11 @@ bool OnlyOperatesOnCompositeDevices( auto& alias_analysis = side_effect_analysis.GetAliasAnalysis(); llvm::SmallSet read_array; for (const Attribute& attr : op.getDeviceVarReadsIndices()) { - read_array.insert(attr.cast().getInt()); + read_array.insert(mlir::cast(attr).getInt()); } llvm::SmallSet update_array; for (const Attribute& attr : op.getDeviceVarUpdatesIndices()) { - update_array.insert(attr.cast().getInt()); + update_array.insert(mlir::cast(attr).getInt()); } for (auto& arg : op->getOpOperands()) { @@ -270,7 +270,7 @@ void CollectChainResources( // // Checks if the value `control` is a NoOp control barrier. bool IsNoOpControlBarrier(Value control) { - if (!control.getType().isa()) return false; + if (!mlir::isa(control.getType())) return false; auto control_island = dyn_cast_or_null(control.getDefiningOp()); if (!control_island) return false; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc index 8e89f3988dd8d4..4af1246d5a72b6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc @@ -38,6 +38,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" @@ -76,7 +77,7 @@ AnonymousIteratorV3Op CreateIterator(OpBuilder builder, llvm::SmallVector type_attrs; for (Type type : dataset_types) { shape_attrs.push_back( - TF::ShapeAttr::get(builder.getContext(), type.cast())); + TF::ShapeAttr::get(builder.getContext(), mlir::cast(type))); type_attrs.push_back(TypeAttr::get(getElementTypeOrSelf(type))); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 51438ac4901b9d..4cdc90376c2317 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -90,7 +90,7 @@ TF::SumOp createSumOp(Value value, Location loc, PatternRewriter* rewriter) { Value redux_op = createI32ConstantOp(redux_axes, loc, rewriter); - auto value_type = value.getType().cast(); + auto value_type = mlir::cast(value.getType()); auto shape = value_type.getShape(); llvm::SmallVector sum_shape; for (int i = 0; i < shape.size(); ++i) { @@ -108,7 +108,7 @@ TF::TransposeOp createTransposeOp(Value value, Location loc, llvm::ArrayRef permutation, PatternRewriter* rewriter) { auto perm_op = createI32ConstantOp(permutation, loc, rewriter); - auto value_type = value.getType().cast(); + auto value_type = mlir::cast(value.getType()); auto shape = value_type.getShape(); SmallVector transposed_shape(shape.begin(), shape.end()); for (int i = 0, end = shape.size(); i < end; ++i) { @@ -529,7 +529,7 @@ LogicalResult rewriteToReduceSumAndTranspose(TF::EinsumOp op, bool needs_transpose = false; for (int64_t i = 0; i < dnums.lhs_out.size(); ++i) { if (std::get<0>(dnums.lhs_out[i]) > - lhs.getType().cast().getRank() - 1) { + mlir::cast(lhs.getType()).getRank() - 1) { continue; } @@ -637,8 +637,8 @@ LogicalResult reshapeForBatchMatmul(const Location& loc, Value* rhs, SmallVectorImpl* out_shape, PatternRewriter* rewriter) { - RankedTensorType lhs_type = lhs->getType().cast(); - RankedTensorType rhs_type = rhs->getType().cast(); + RankedTensorType lhs_type = mlir::cast(lhs->getType()); + RankedTensorType rhs_type = mlir::cast(rhs->getType()); int32_t num_lhs_reshape_segids = 0; int32_t num_rhs_reshape_segids = 0; @@ -776,7 +776,7 @@ LogicalResult rewriteToBatchMatmul(TF::EinsumOp op, EinsumDimensionNumbers original_dnums = dnums; RankedTensorType original_type = - op.getResult().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getResult().getType()); if (!original_type) return failure(); std::vector out_transpose; @@ -822,7 +822,7 @@ LogicalResult matchAndRewriteUnaryEinsumOp(TF::EinsumOp op, op, "Function only supports unary einsum op"); } RankedTensorType lhs = - op.getOperand(0).getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getOperand(0).getType()); if (!lhs) { return failure(); } @@ -862,9 +862,9 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( } RankedTensorType lhs = - op.getOperand(0).getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getOperand(0).getType()); RankedTensorType rhs = - op.getOperand(1).getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getOperand(1).getType()); if (!lhs 
|| !rhs) { return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index 51afea6d84671e..e1611432f36e8c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/core/platform/logging.h" @@ -443,7 +444,7 @@ void InsertDummyIslandForFetch(FetchOp fetch) { control_fetches.reserve(data_fetches.capacity()); for (auto value : fetch.getFetches()) { - if (value.getType().isa()) { + if (mlir::isa(value.getType())) { control_fetches.push_back(value); } else { data_fetches.push_back(value); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc index 9567278d98dc9c..b75f081d1a0064 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc @@ -395,7 +395,7 @@ bool is_valid_special_tpu_op( bool op_has_inconsistent_cluster_name = wrapped_op_cluster_name.has_value() && - !wrapped_op_cluster_name.value().equals(cluster_name); + wrapped_op_cluster_name.value() != cluster_name; if (op_has_inconsistent_cluster_name) { return false; @@ -624,7 +624,7 @@ void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() { assert(!funcs_for_cluster->second.empty()); if (funcs_for_cluster->second.size() == 1) return false; for (NamedAttribute attr : op->getAttrs()) { - auto symbol_ref = attr.getValue().dyn_cast(); + auto symbol_ref = mlir::dyn_cast(attr.getValue()); if (!symbol_ref) continue; func::FuncOp callee = symbol_table.lookup(symbol_ref.getValue()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index 0106d149d3d343..19603170b89e20 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -178,13 +178,14 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() { for (func::FuncOp func : outlined_module.getOps()) { func.walk([&](Operation *op) { for (NamedAttribute attr : op->getAttrs()) { - if (auto symbol_ref = attr.getValue().dyn_cast()) { + if (auto symbol_ref = + mlir::dyn_cast(attr.getValue())) { MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table); continue; } - if (auto array_attr = attr.getValue().dyn_cast()) { + if (auto array_attr = mlir::dyn_cast(attr.getValue())) { for (const Attribute &attribute : array_attr) { - auto symbol_ref = attribute.dyn_cast(); + auto symbol_ref = mlir::dyn_cast(attribute); if (!symbol_ref) continue; MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc 
b/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc index dfd20d8dd0e07a..18480fbd772fa9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -47,7 +48,7 @@ class ExtractTPUCopyWithDynamicShapeOpPass // Finds op that created a given value. If the value is a BlockArgument, this // returns the owner of the Block. Operation* GetOpOfValue(Value value) { - if (auto block_arg = value.dyn_cast()) + if (auto block_arg = mlir::dyn_cast(value)) return block_arg.getOwner()->getParentOp(); return value.getDefiningOp(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc index d755696c74607b..6547b6f168c3bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -144,7 +145,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the result shape is fully defined. auto result_type = - op->getResultTypes().front().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op->getResultTypes().front()); if (!result_type || !result_type.hasStaticShape()) return failure(); bool changed = false; @@ -155,15 +156,13 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( if (!broadcast) continue; // Check that the operand of the broadcast has fully defined shape. - auto broadcast_arg_type = - broadcast.getInput().getType().dyn_cast_or_null(); + auto broadcast_arg_type = mlir::dyn_cast_or_null( + broadcast.getInput().getType()); if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; // Check that the other argument has fully defined shape. - auto argument_type = op->getOpOperand(1 - i) - .get() - .getType() - .dyn_cast_or_null(); + auto argument_type = mlir::dyn_cast_or_null( + op->getOpOperand(1 - i).get().getType()); if (!argument_type || !argument_type.hasStaticShape()) continue; // Get the unbroadcasted shapes in the operand order. 
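The hunks in fold_broadcast.cc above, like most of the C++ hunks in this patch, apply one mechanical migration: the deprecated member-function casts on mlir::Type, mlir::Attribute, and mlir::Value (x.cast<T>(), x.dyn_cast<T>(), x.dyn_cast_or_null<T>(), x.isa<T>()) become the free functions mlir::cast / mlir::dyn_cast / mlir::dyn_cast_or_null / mlir::isa. That is also why "mlir/Support/LLVM.h" (which re-exports the LLVM casting helpers into the mlir namespace) and the @llvm-project//mlir:Support BUILD dependency are added to so many files. A minimal before/after sketch; RankedTensorType is only an illustrative target type, since the actual template arguments were dropped from this copy of the patch:

#include "mlir/IR/BuiltinTypes.h"  // mlir::RankedTensorType
#include "mlir/IR/Value.h"         // mlir::Value
#include "mlir/Support/LLVM.h"     // brings llvm::cast/dyn_cast/isa into mlir

// Old, deprecated member-function style:
//   auto ty = v.getType().dyn_cast_or_null<mlir::RankedTensorType>();
//   bool ok = v.getType().isa<mlir::RankedTensorType>();
// New free-function style, as applied throughout this patch:
mlir::RankedTensorType GetRankedType(mlir::Value v) {
  return mlir::dyn_cast_or_null<mlir::RankedTensorType>(v.getType());
}
bool IsRanked(mlir::Value v) {
  return mlir::isa<mlir::RankedTensorType>(v.getType());
}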
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index 6a1a4852e68a1c..9f9da90bf76594 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -86,7 +86,7 @@ void FreezeGlobalTensorsPass::runOnOperation() { DenseMap freezeable; for (auto func : module.getOps()) { for (BlockArgument val : func.getArguments()) { - if (!getElementTypeOrSelf(val.getType()).isa()) + if (!mlir::isa(getElementTypeOrSelf(val.getType()))) continue; // Check that there is only a single global tensor associated with arg. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index 0cff8946687dcb..11be79869f4fd2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -101,7 +102,7 @@ YieldOp CreateCall(Operation* op, func::FuncOp func, Region& caller_region, // Converts the condition for an IfOp/WhileOp to a boolean value. Value ConvertConditionToBoolean(Operation* op, Value cond) { - if (auto ranked_type = cond.getType().dyn_cast()) + if (auto ranked_type = mlir::dyn_cast(cond.getType())) if (ranked_type.getRank() == 0 && ranked_type.getElementType().isSignlessInteger(1)) return cond; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc index 1c0a125598cdbe..4eb791a909022d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc @@ -332,9 +332,9 @@ class FuseMatMulBiasAdd } // FusedMatMul kernel does not support grad_a/grad_b attrs if ((matmul->hasAttr("grad_a") && - matmul->getAttr("grad_a").cast().getValue()) || + mlir::cast(matmul->getAttr("grad_a")).getValue()) || (matmul->hasAttr("grad_b") && - matmul->getAttr("grad_b").cast().getValue())) { + mlir::cast(matmul->getAttr("grad_b")).getValue())) { (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) { diag << "FusedMatMul kernel does not support grad_a/grad_b attrs"; }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index dff2223b11567f..e49b0f445d0c70 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -43,7 +43,7 @@ Status MlirGraphOptimizationPass::Run( ::tensorflow::MlirOptimizationPassState::Disabled) { VLOG(1) << "Skipping MLIR Graph Optimization Pass" << ", session flag not enabled"; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } VLOG(1) << "Run MLIR Graph Optimization Passes"; 
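The graph_optimization_pass.cc hunk above (like the one in constant_fold_utils.cc earlier) swaps ::tensorflow::OkStatus() for absl::OkStatus(). A minimal sketch of the idiom, assuming tensorflow::Status is the usual alias of absl::Status so the two spellings construct the same OK value; the helper below is hypothetical and only illustrates the early-return pattern used in the patched code:

#include "absl/status/status.h"

// Returning OK via the canonical absl spelling instead of the legacy
// tensorflow::OkStatus() alias. Skipping a disabled pass is not an error.
absl::Status MaybeRunOptimizationPasses(bool pass_enabled) {
  if (!pass_enabled) return absl::OkStatus();
  // ... run the MLIR graph optimization passes ...
  return absl::OkStatus();
}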
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc index 78fb6aad3abdde..91f14794494de7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -135,7 +136,7 @@ void HoistLoopInvariantPass::runOnOperation() { // Skip the pass if the function inputs contain any resource. for (const auto &type : func.getArgumentTypes()) { - if (getElementTypeOrSelf(type).isa()) return; + if (mlir::isa(getElementTypeOrSelf(type))) return; } llvm::DenseSet read_only_vars = GetReadOnlyVariables(func); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index f8e75d9032f3e5..3d046b4c41c51f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -252,6 +252,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc index 767d5cf7f0cf8c..a21c78a9e3ca82 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" @@ -93,7 +94,7 @@ LogicalResult SetMetadataProtoStepMarkerLocation( // Parses a xla::OpSharding from a string attribute. LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name, int index, xla::OpSharding* sharding_ptr) { - auto sharding_attr = attr.dyn_cast(); + auto sharding_attr = mlir::dyn_cast(attr); if (!sharding_attr) return op->emitOpError( llvm::formatv(kBadStringArrayElementMsg, name, index)); @@ -130,7 +131,7 @@ LogicalResult SetMetadataProtoArgs( llvm::SmallSet dynamic_arg_idx_set; if (dynamic_arg_idx) { for (auto idx : dynamic_arg_idx.getValue()) { - dynamic_arg_idx_set.insert(idx.dyn_cast().getInt()); + dynamic_arg_idx_set.insert(mlir::dyn_cast(idx).getInt()); } } @@ -155,7 +156,8 @@ LogicalResult SetMetadataProtoArgs( // Populate argument shapes. 
*arg->mutable_shape() = tensorflow::TensorShapeProto(); - if (auto ranked_tensor_type = operand_type.dyn_cast()) { + if (auto ranked_tensor_type = + mlir::dyn_cast(operand_type)) { tensorflow::TensorShapeProto shape_proto; ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto); *arg->mutable_shape() = std::move(shape_proto); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc index ed1e0549dfb769..d8067af3f29557 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "absl/strings/match.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -43,6 +42,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -75,7 +75,6 @@ namespace mlir { namespace TFTPU { constexpr char kStepMarkerLocationAttr[] = "step_marker_location"; -constexpr char kDeviceAttr[] = "device"; constexpr char kDevicesAttr[] = "devices"; constexpr char kVersionsAttr[] = "tf.versions"; constexpr char kUseXlaSpmdAttr[] = "use_spmd_for_xla_partitioning"; @@ -139,7 +138,7 @@ LogicalResult EncapsulateFuncAndSerialize(const std::string& module_name, assert(uses && "expected to be able to collect symbol uses"); for (SymbolTable::SymbolUse use : *uses) { func::FuncOp referenced_func = entry_module_table.lookup( - use.getSymbolRef().cast().getValue()); + mlir::cast(use.getSymbolRef()).getValue()); // Skip Symbols that do not map to a function. if (!referenced_func) continue; @@ -380,18 +379,9 @@ LogicalResult AddToParallelExecuteOp( // If computation is replicated, use aliased device. Otherwise there is only // one execution device per core and the device is assigned to the execute // op. - std::string device; - if (replicated) { - device = tensorflow::GetDeviceAliasForLogicalCore(core); - } else { - auto device_attr = cluster_func->getAttrOfType(kDeviceAttr); - if (device_attr && !device_attr.str().empty() && - absl::StrContains(device_attr.str(), "TPU:")) { - device = cluster_func->getAttrOfType(kDeviceAttr).str(); - } else { - device = tpu_devices.front()[core].device; - } - } + std::string device = replicated + ? tensorflow::GetDeviceAliasForLogicalCore(core) + : tpu_devices.front()[core].device; auto block_launch_op = tensorflow::WrapOpInLaunch( builder, block.getParent()->getLoc(), execute, device); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_variable_runtime_reformatting.cc index 271e525ef8c7ae..4e87c10b1b7ac6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_variable_runtime_reformatting.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -73,7 +74,7 @@ struct TPUVariableRuntimeReformattingPass // provided, it will be used to store the identity nodes skipped. Value SkipIdentity(Value v, bool allow_other_use, llvm::SmallPtrSet* skipped = nullptr) { - while (auto result = v.dyn_cast()) { + while (auto result = mlir::dyn_cast(v)) { if (!(allow_other_use || v.hasOneUse())) break; auto op = result.getDefiningOp(); if (!llvm::isa(op)) { @@ -108,10 +109,10 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( for (auto index_and_arg : llvm::enumerate(execute.getArgs())) { auto arg = SkipIdentity(index_and_arg.value(), /*allow_other_use=*/false); if (!arg.hasOneUse() || - !getElementTypeOrSelf(arg.getType()).isa()) { + !mlir::isa(getElementTypeOrSelf(arg.getType()))) { continue; } - auto block_arg = arg.dyn_cast(); + auto block_arg = mlir::dyn_cast(arg); if (!block_arg || block_arg.getOwner() != &replicate.GetBody()) continue; assert(replicate_arg_to_execute_arg.count(block_arg.getArgNumber()) == 0 && "Found duplicate use of a resource in the execute op."); @@ -131,13 +132,13 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( // variables (arguments of `replicate`), and must be pass-throughs from while // operands. for (const auto& mirrored_index : mirrored_variable_indices_attr) { - int64_t replicate_arg = mirrored_index.cast().getInt(); + int64_t replicate_arg = mlir::cast(mirrored_index).getInt(); // Check if the mirrored variable is an input to `execute`. auto it = replicate_arg_to_execute_arg.find(replicate_arg); if (it == replicate_arg_to_execute_arg.end()) continue; // Get the data type of the resource. - auto subtypes = getElementTypeOrSelf(execute.getOperand(it->second)) - .cast() + auto subtypes = mlir::cast( + getElementTypeOrSelf(execute.getOperand(it->second))) .getSubtypes(); if (subtypes.size() != 1) continue; auto data_type = getElementTypeOrSelf(subtypes[0]); @@ -198,14 +199,14 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( llvm::sort(mapping, llvm::less_first()); // Populate the `retval_index_for_sharding` field of the argument metadate. 
for (auto entry : llvm::enumerate(execute.getDeviceVarReadsIndices())) { - int64_t arg_index = entry.value().cast().getInt(); + int64_t arg_index = mlir::cast(entry.value()).getInt(); auto arg_metadata = metadata.mutable_args(arg_index); if (arg_metadata->enable_xla_sharding() == ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED) { - int64_t ret_index = execute.getDeviceVarUpdatesIndices() - .getValue()[entry.index()] - .cast() - .getInt(); + int64_t ret_index = + mlir::cast( + execute.getDeviceVarUpdatesIndices().getValue()[entry.index()]) + .getInt(); arg_metadata->set_retval_index_for_sharding(ret_index); } } @@ -379,12 +380,13 @@ bool HandleReplicateOp(TF::WhileRegionOp while_op, for (auto it : device_map) { auto device_alias = it.getName().strref(); - auto device_list = it.getValue().cast(); + auto device_list = mlir::cast(it.getValue()); llvm::SmallVector device_list_for_alias; device_list_for_alias.reserve(device_list.size()); for (auto device : device_list) - device_list_for_alias.emplace_back(device.cast().getValue()); + device_list_for_alias.emplace_back( + mlir::cast(device).getValue()); devices.insert({device_alias, device_list_for_alias}); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc index 3b974c395706aa..6e7fe42ef4dfab 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -99,8 +100,7 @@ func::FuncOp GetOrCreateSessionInitFunc(ModuleOp module) { // tf_saved_model.initializer_type attribute was introduced. SymbolTable symbol_table(module); return symbol_table.lookup( - session_init_op.getInitializers()[0] - .cast() + mlir::cast(session_init_op.getInitializers()[0]) .getValue()); } else { return CreateSessionInitFunc(module); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc index aa1efc6837eee6..015499c6996f38 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -66,7 +67,7 @@ LogicalResult AssignDevicesInRegion(const Dialect* tf_dialect, return WalkResult::advance(); } - if (auto device_str_attr = device_attr.dyn_cast()) { + if (auto device_str_attr = mlir::dyn_cast(device_attr)) { if (device_str_attr.getValue().empty()) { op->setAttr(kDeviceAttr, launch.getDeviceAttr()); return WalkResult::advance(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index 0fad3c019ea432..e8c1d1997e195e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h" @@ -49,7 +50,7 @@ TransposeOp ReuseExistingTranspose(const OpOperand* operand, auto tranpose_op = *it; for (auto tranpose_operand : tranpose_op.getOperands()) { auto ranked_tranpose_type = - tranpose_operand.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(tranpose_operand.getType()); if (!ranked_tranpose_type) continue; if (ranked_tranpose_type.getRank() == permutation.size() && operand->get().getType() == @@ -201,7 +202,7 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { if (!perm) return; // With the same permutation indices. - auto dense_elem_attr = perm.getValue().dyn_cast(); + auto dense_elem_attr = mlir::dyn_cast(perm.getValue()); if (!dense_elem_attr) return; if (!permutation_op) permutation_op = perm; @@ -217,7 +218,7 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { // Nothing to do here. if (!permutation_op || transpose_ops.empty()) return; SmallVector permutation; - auto perm_attr = permutation_op.getValue().cast(); + auto perm_attr = mlir::cast(permutation_op.getValue()); for (const auto& value : perm_attr.getValues()) permutation.push_back(value.getSExtValue()); @@ -227,10 +228,11 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { if (op->hasTrait()) { auto transpose_op = *transpose_ops.begin(); auto result_type = - transpose_op.getResult().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(transpose_op.getResult().getType()); auto is_valid_move = llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool { - auto operand_type = operand.getType().dyn_cast_or_null(); + auto operand_type = + mlir::dyn_cast_or_null(operand.getType()); return result_type && operand_type && result_type.hasRank() && operand_type.hasRank() && result_type.getRank() == operand_type.getRank(); @@ -343,7 +345,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, if (!perm) return; // With the same permutation indices. 
- auto dense_elem_attr = perm.getValue().dyn_cast(); + auto dense_elem_attr = mlir::dyn_cast(perm.getValue()); if (!dense_elem_attr) return; if (!permutation_op) permutation_op = perm; @@ -365,7 +367,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, SmallVector permutation; - auto attr = permutation_op.getValue().cast(); + auto attr = mlir::cast(permutation_op.getValue()); for (const auto& value : attr.getValues()) permutation.push_back(value.getSExtValue()); @@ -373,7 +375,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, if (fold_operands && fold_transpose_in_ops) { SmallVector permutation; - auto attr = permutation_op.getValue().cast(); + auto attr = mlir::cast(permutation_op.getValue()); for (const auto& value : attr.getValues()) permutation.push_back(value.getSExtValue()); @@ -408,7 +410,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, // update the result type in `FoldOperandsPermutation`. if (layout_agnostic) result.setType(ReversePermuteShapedType( - result.getType().cast(), permutation)); + mlir::cast(result.getType()), permutation)); // Try to push transpose further down. for (Operation* user : result.getUsers()) { @@ -422,7 +424,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, transpose.getOperation()->moveBefore(op->getNextNode()); transpose.setOperand(0, result); transpose.setOperand(1, permutation_op); - transpose.getResult().setType(original_type[idx].cast()); + transpose.getResult().setType(mlir::cast(original_type[idx])); } else { transpose = builder.create(loc, result, permutation_op); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc index bc0534fdb0bb84..c4ea84d8b0948c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc @@ -189,10 +189,10 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) { func, arg_number, symbol_table); if (!global_tensor) continue; - auto arg_type = arg.getType().cast(); + auto arg_type = mlir::cast(arg.getType()); assert(arg_type.getRank() == 0); llvm::ArrayRef underlying_type = - arg_type.getElementType().cast().getSubtypes(); + mlir::cast(arg_type.getElementType()).getSubtypes(); // If the arg type already matches the global_tensor type, we don't need // to do anything. 
@@ -206,7 +206,7 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) { auto new_arg_type = mlir::RankedTensorType::get( /*shape=*/{}, mlir::TF::ResourceType::get( - /*subtypes=*/{global_tensor.getType().cast()}, + /*subtypes=*/{mlir::cast(global_tensor.getType())}, module.getContext())); arg.setType(new_arg_type); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc index cdb256ab25f177..8d58b8177b33c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc @@ -47,7 +47,7 @@ static LogicalResult traceUpwardsToArgument(Value v, llvm::DenseSet seen, } seen.insert(v); - if (auto blockArg = v.dyn_cast()) { + if (auto blockArg = mlir::dyn_cast(v)) { Operation *op = blockArg.getOwner()->getParentOp(); // If we're in the first block, then the argument to that block is the diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 1c9b1e03a663c6..da565f00b45b99 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" @@ -79,7 +80,7 @@ static DenseElementsAttr GetF32Scalar(OpBuilder *builder, float value) { // Preconditions: The given value must have a ShapedType. static Value CreateTFCastOpF32(OpBuilder *builder, Location loc, Value x, BoolAttr truncate) { - auto x_type = x.getType().dyn_cast_or_null(); + auto x_type = mlir::dyn_cast_or_null(x.getType()); if (!x_type) llvm_unreachable("unsupported type"); Type type = x_type.clone(builder->getF32Type()); return builder->create(loc, type, x, truncate); @@ -92,7 +93,7 @@ static Value CreateTFCastOpF32(OpBuilder *builder, Location loc, Value x, // Preconditions: The given value must have a ShapedType. static Value CreateTFCastOpI32(OpBuilder *builder, Location loc, Value x, BoolAttr truncate) { - auto x_type = x.getType().dyn_cast_or_null(); + auto x_type = mlir::dyn_cast_or_null(x.getType()); if (!x_type) llvm_unreachable("unsupported type"); Type type = x_type.clone(builder->getI32Type()); return builder->create(loc, type, x, truncate); @@ -109,7 +110,8 @@ static APFloat ConvertToAPFloat(double val, Type type) { // Performs the operation of `Shape(input)[idx]`. static Value GetDimensionSize(OpBuilder *builder, Location loc, Value input, int32_t idx, BoolAttr use_32bit) { - if (auto ranked_ty = input.getType().dyn_cast_or_null()) { + if (auto ranked_ty = + mlir::dyn_cast_or_null(input.getType())) { // Canonicalize negative index. if (idx < 0) { idx += ranked_ty.getRank(); @@ -154,7 +156,7 @@ bool QuantizedTypeIsUnsigned(Type type) { // to offset the quantized representation before it gets scaled. In the case // of negative quantize types, this offset is half the type's range. 
static DenseElementsAttr DequantizeHalfRange(OpBuilder *builder, Value input) { - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) llvm_unreachable("DequantizeHalfRange: not a ShapedType"); bool is_unsigned = QuantizedTypeIsUnsigned(input_type.getElementType()); float half_range = is_unsigned ? 0 : 128; @@ -183,7 +185,7 @@ DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank, // Infers ExpandDims op output type for the given input type `ty` and dimension // to expand at the given `axis`. Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); // Unranked type. if (!ranked_ty) return ty; @@ -258,7 +260,7 @@ class LowerAddNOp : public RewritePattern { // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't // support variant type so variant types require special handling. - if (getElementTypeOrSelf(addn_op.getType()).isa()) + if (mlir::isa(getElementTypeOrSelf(addn_op.getType()))) return failure(); llvm::SmallVector operands(addn_op.getInputs().begin(), addn_op.getInputs().end()); @@ -324,8 +326,7 @@ class LowerDynamicStitchOp : public RewritePattern { // Static output type is used to compute intermediate values. Note that the // output type doesn't have to be static but if input types and indices are // constant, then the output type can be statically determined. - RankedTensorType out_ty = - op.getType().template dyn_cast(); + RankedTensorType out_ty = mlir::dyn_cast(op.getType()); if (!out_ty || !out_ty.hasStaticShape()) return failure(); // Extract out all the constant indices' attributes and verify that data @@ -341,7 +342,7 @@ class LowerDynamicStitchOp : public RewritePattern { indices.push_back(index_attr); RankedTensorType data_ty = - data.getType().template dyn_cast(); + mlir::dyn_cast(data.getType()); if (!data_ty || !data_ty.hasStaticShape()) return failure(); } @@ -367,9 +368,8 @@ class LowerDynamicStitchOp : public RewritePattern { auto reshaped_data = rewriter.create(loc, data, packed_shape_val); - auto num_items = reshaped_data.getType() - .template cast() - .getShape()[0]; + auto num_items = + mlir::cast(reshaped_data.getType()).getShape()[0]; auto items = rewriter.create( loc, SmallVector(num_items, item_ty), reshaped_data, /*axis=*/0); @@ -407,7 +407,7 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern { auto op = cast(src_op); auto input = op.getInputs(); - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); auto element_ty = input_ty.getElementType(); auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, element_ty); @@ -534,7 +534,7 @@ class LowerInvertPermutationOp : public RewritePattern { auto op = cast(src_op); Location loc = op.getLoc(); - auto x_type = op.getX().getType().dyn_cast(); + auto x_type = mlir::dyn_cast(op.getX().getType()); // x input must have static shape. if (!x_type || !x_type.hasStaticShape()) { return failure(); @@ -617,12 +617,13 @@ class LowerLgammaOp : public RewritePattern { Location loc = op.getLoc(); Value input = op.getX(); - TensorType original_tensor_type = op.getX().getType().cast(); + TensorType original_tensor_type = + mlir::cast(op.getX().getType()); // The approximation is not precise enough for float16. Do the computation // in float32 for that case. 
TensorType tensor_type = original_tensor_type; - FloatType float_type = tensor_type.getElementType().cast(); + FloatType float_type = mlir::cast(tensor_type.getElementType()); bool needs_cast = float_type.getWidth() < 32; if (needs_cast) { MLIRContext *context = rewriter.getContext(); @@ -887,17 +888,18 @@ class LowerSpaceToBatchNDOp : public RewritePattern { auto op = cast(src_op); Location loc = op.getLoc(); - auto input_type = op.getInput().getType().cast(); + auto input_type = mlir::cast(op.getInput().getType()); auto element_type = input_type.getElementType(); if (!input_type.hasStaticShape()) { return failure(); } ArrayRef input_shape = input_type.getShape(); - auto block_shape_type = op.getBlockShape().getType().cast(); + auto block_shape_type = + mlir::cast(op.getBlockShape().getType()); if (!block_shape_type.hasStaticShape()) { return failure(); } - auto paddings_type = op.getPaddings().getType().cast(); + auto paddings_type = mlir::cast(op.getPaddings().getType()); if (!paddings_type.hasRank()) { return failure(); } @@ -1100,7 +1102,7 @@ class LowerBatchToSpaceND : public RewritePattern { PatternRewriter &rewriter) const override { auto op = cast(src_op); auto input = op.getInput(); - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); auto element_ty = input_ty.getElementType(); if (!input_ty.hasStaticShape()) { return failure(); @@ -1279,9 +1281,7 @@ class LowerSparseMatMulOp : public RewritePattern { // Result type must be f32 for applying the pattern (currently this is // required by the op anyway but this might change). - if (!op.getProduct() - .getType() - .cast() + if (!mlir::cast(op.getProduct().getType()) .getElementType() .isF32()) { return failure(); @@ -1289,7 +1289,7 @@ class LowerSparseMatMulOp : public RewritePattern { MLIRContext *context = rewriter.getContext(); llvm::SmallVector operands{op.getA(), op.getB()}; for (Value &operand : operands) { - TensorType tensor_type = operand.getType().cast(); + TensorType tensor_type = mlir::cast(operand.getType()); Type element_type = tensor_type.getElementType(); if (element_type.isF32()) continue; // Element type can either be f32 or bf16 for `SparseMatMulOp` so it @@ -1374,13 +1374,13 @@ class LowerResizeNearestNeighbor : public RewritePattern { PatternRewriter &rewriter) const override { auto op = cast(src_op); auto loc = op.getLoc(); - auto result_ty = op.getType().cast(); + auto result_ty = mlir::cast(op.getType()); auto input = op.getImages(); - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); auto input_element_ty = input_ty.getElementType(); auto out_size = op.getSize(); - auto out_size_ty = out_size.getType().cast(); + auto out_size_ty = mlir::cast(out_size.getType()); auto out_size_element_ty = out_size_ty.getElementType(); // Input should be rank 4. 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
index cd608bdf269ad7..80e7cd3991c727 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
@@ -57,13 +58,14 @@ class SimplifyBroadcastReshape : public OpRewritePattern<BroadcastToOp> {
     auto reshape_op = llvm::dyn_cast_or_null<ReshapeOp>(user);
     if (!reshape_op) return failure();

-    auto reshape_type = reshape_op.getOutput().getType().cast<ShapedType>();
+    auto reshape_type =
+        mlir::cast<ShapedType>(reshape_op.getOutput().getType());

     if (!reshape_type.hasStaticShape()) return failure();
     ArrayRef<int64_t> reshape_shape = reshape_type.getShape();

-    auto input_type = op.getInput().getType().cast<ShapedType>();
-    auto output_type = op.getOutput().getType().cast<ShapedType>();
+    auto input_type = mlir::cast<ShapedType>(op.getInput().getType());
+    auto output_type = mlir::cast<ShapedType>(op.getOutput().getType());

     if (!input_type.hasRank() || !output_type.hasRank()) return failure();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
index eaf881c43df95e..bfed05448bd25a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
@@ -94,7 +94,7 @@ GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
       continue;
     }
     auto global_tensor = symbol_table.lookup<GlobalTensorOp>(
-        sym.cast<FlatSymbolRefAttr>().getValue());
+        mlir::cast<FlatSymbolRefAttr>(sym).getValue());
     if (!global_tensor) {
       continue;
     }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 3a2ba6f181f649..74f25f90dedf33 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_

+#include <cstdint>
 #include <memory>
 #include <string>
@@ -99,7 +100,8 @@ std::unique_ptr<OperationPass<func::FuncOp>>
 CreateReplicateTensorListInitOpsPass();

 // Performs Shape Inference on the TensorFlow dialect using the global registry.
-std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass(
+    ArrayRef<ArrayRef<int64_t>> input_shapes = {});

 // Performs TF.data optimizations.
 std::unique_ptr<OperationPass<func::FuncOp>> CreateTFDataOptimizationPass();
@@ -308,9 +310,8 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateNameAnonymousIteratorsPass();

 // Creates a pass that breaks up an island with multiple ops into multiple
 // islands, each with a single op. This pass intentionally does not propagate
-// control dependencies across newly created islands, a following pass will
-// handle this.
-// TODO(b/244596254) Implement followup pass for creating control deps.
+// control dependencies across newly created islands; propagation is handled
+// by CreateTFExecutorUpdateControlDependenciesPass.
 std::unique_ptr<OperationPass<func::FuncOp>> CreateSplitIntoIslandPerOpPass();

 // Prints, but otherwise pipes through without changes, the current module.
@@ -531,10 +532,6 @@ CreateTPUResourceReadsWritesPartitioningPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 CreateTPUAnnotateDynamicShapeInputsPass();

-// Creates a pass that identifies XLASharding ops in launch op for TPU
-// computation.
-std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass();
-
 // Creates a pass that moves `tf.AssignVariableOp` into a
 // `tf_device.parallel_execute` region if the `tf.AssignVariableOp` is the
 // only consumer of a `tf_device.parallel_execute` result.
@@ -668,7 +665,6 @@ enum MoveTransposeDirection { kBegin, kEnd };
 #define GEN_PASS_DECL_TPUREORDERREPLICATEANDPARTITIONEDINPUTSPASS
 #define GEN_PASS_DECL_TPURESOURCEREADFORWRITEPASS
 #define GEN_PASS_DECL_TPURESOURCEREADSWRITESPARTITIONINGPASS
-#define GEN_PASS_DECL_TPUSHARDINGIDENTIFICATIONPASS
 #define GEN_PASS_DECL_TPUSPACETODEPTHPASS
 #define GEN_PASS_DECL_TPUUPDATEEMBEDDINGENQUEUEOPINPUTSPASS
 #define GEN_PASS_DECL_TPUVALIDATEINPUTSPASS
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc
index 661dafe2a2f327..b968923089cb8f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" @@ -79,8 +80,8 @@ class RewriteXlaHostComputeMlir llvm::SmallVector shape_attrs; shape_attrs.reserve(op.getNumResults()); for (Type ty : op.getResultTypes()) { - shape_attrs.push_back( - TF::ShapeAttr::get(rewriter.getContext(), ty.cast())); + shape_attrs.push_back(TF::ShapeAttr::get(rewriter.getContext(), + mlir::cast(ty))); } // Clone the `host_func` in the `host_mlir_module` attribute if it exists diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index a7226b39ebe380..bc64c48c81a596 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -196,8 +197,8 @@ LogicalResult PromoteResourcesToArguments( auto func_args = function.getArguments().take_front( function.getNumArguments() - var_handle_shared_names.size()); for (BlockArgument& func_arg : func_args) { - auto resource_type = - getElementTypeOrSelf(func_arg.getType()).dyn_cast(); + auto resource_type = mlir::dyn_cast( + getElementTypeOrSelf(func_arg.getType())); if (!resource_type) continue; if (failed(ValidateResourceArgument(function, func_arg, resource_type))) return failure(); @@ -212,8 +213,8 @@ LogicalResult PromoteResourcesToArguments( auto var_handle_args = function.getArguments().take_back(var_handle_shared_names.size()); for (BlockArgument& var_handle_arg : var_handle_args) { - auto resource_type = - getElementTypeOrSelf(var_handle_arg.getType()).cast(); + auto resource_type = mlir::cast( + getElementTypeOrSelf(var_handle_arg.getType())); add_resource_argument(var_handle_arg, resource_type); } @@ -226,7 +227,8 @@ LogicalResult PromoteResourcesToArguments( // live value. 
for (Operation& op : llvm::make_early_inc_range(block)) { if (auto read_op = llvm::dyn_cast(&op)) { - if (auto func_arg = read_op.getResource().dyn_cast()) { + if (auto func_arg = + mlir::dyn_cast(read_op.getResource())) { if (func_arg.getOwner() != &block) return read_op.emitOpError(kResourceFunctionMsg); @@ -239,7 +241,8 @@ LogicalResult PromoteResourcesToArguments( read_op.erase(); } else if (auto write_op = llvm::dyn_cast(&op)) { - if (auto func_arg = write_op.getResource().dyn_cast()) { + if (auto func_arg = + mlir::dyn_cast(write_op.getResource())) { if (func_arg.getOwner() != &block) return write_op.emitOpError(kResourceFunctionMsg); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc index 975a1484d6984a..7c488b8992d2cb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -91,7 +92,7 @@ StringRef GetNodeNameFromClassAttrOrSharedNameAttr(Operation *op) { StringRef result; for (Attribute class_attr : classes_attr) { - StringRef node_name = class_attr.cast().getValue(); + StringRef node_name = mlir::cast(class_attr).getValue(); if (!node_name.starts_with(kLocationPrefix)) { continue; } @@ -150,8 +151,8 @@ void ConvertReadonlyReferenceVariablesToResourceVariablesPass:: for (VariableV2Op variable_v2_op : variable_v2s_to_replace) { builder.setInsertionPoint(variable_v2_op); ShapedType shaped_type = - variable_v2_op.getResult().getType().cast(); - TensorType tensor_type = DropRefType(shaped_type).cast(); + mlir::cast(variable_v2_op.getResult().getType()); + TensorType tensor_type = mlir::cast(DropRefType(shaped_type)); StringAttr device_attr = variable_v2_op->getAttrOfType("device"); if (!device_attr) device_attr = builder.getStringAttr(""); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index a669276e35a175..b740e667dabe84 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -508,10 +508,10 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( // existing function as is. auto while_arg_matcher = [](Value first, Region& first_region, Value second, Region& second_region) { - if (!first.isa() || !second.isa()) + if (!mlir::isa(first) || !mlir::isa(second)) return false; - BlockArgument first_block_arg = first.cast(); - BlockArgument second_block_arg = second.cast(); + BlockArgument first_block_arg = mlir::cast(first); + BlockArgument second_block_arg = mlir::cast(second); // 2 block arguments will match if they are the same argument number, and // are block arguments of the corresponding containing regions. 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc
index 6aa3d161c0e121..18f54d6b5826d3 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "mlir/Interfaces/ControlFlowInterfaces.h"  // from @llvm-project
 #include "mlir/Interfaces/FunctionInterfaces.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project

 namespace mlir {
 namespace TF {
@@ -137,7 +138,7 @@ void RemoveUnusedArgumentsPass::runOnOperation() {
     // SymbolUserOpInterface doesn't tell us which attributes contain
     // the symbols, so we have to scan through all of them.
     for (auto attr : op->getAttrs()) {
-      if (auto sym = attr.getValue().dyn_cast<SymbolRefAttr>()) {
+      if (auto sym = mlir::dyn_cast<SymbolRefAttr>(attr.getValue())) {
         Operation* func = mlir::SymbolTable::lookupNearestSymbolFrom(op, sym);
         if (func) {
           do_not_touch.insert(func);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc
index 9e85c5f9ed6fda..3a6377a3bb63e1 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc
@@ -55,7 +55,7 @@ void RecursiveRemove(Operation* op,
   erase_list.push_back(op);

   for (auto& use : op->getOpOperands()) {
-    if (auto op_result = use.get().dyn_cast<mlir::OpResult>()) {
+    if (auto op_result = mlir::dyn_cast<mlir::OpResult>(use.get())) {
       Operation* def = op_result.getDefiningOp();
       if (!dead_ops.insert(def).second) continue;
       RecursiveRemove(def, erase_list, dead_ops);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc
index 1c9558eecda702..803f135af624d7 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
@@ -90,7 +91,7 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
   Value input = shape_op.getInput();
   // If ShapeOp operand is replicate tensor block argument, replace with the
   // associated first replica operand.
-  if (auto block_arg = input.dyn_cast<BlockArgument>()) {
+  if (auto block_arg = mlir::dyn_cast<BlockArgument>(input)) {
     if (block_arg.getOwner() != replicate_block) return;

     shape_op.setOperand(replicate_op.GetReplicaOperandForBlockArgument(
@@ -112,7 +113,8 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
   // shape has not changed in replicate prior to read. Currently after both
   // ResourceOpLiftingPass and TPURewritePass, there should not be any updates
   // to resources prior to their respective ReadVariableOp.
-  if (auto block_arg = read_var_op.getResource().dyn_cast<BlockArgument>()) {
+  if (auto block_arg =
+          mlir::dyn_cast<BlockArgument>(read_var_op.getResource())) {
     if (block_arg.getOwner() != replicate_block) return;

     OpBuilder builder(shape_op);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index e03eb9a9228f35..90397e7f8237c9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -74,14 +74,14 @@ struct ResourceOpLiftingPass
 };

 bool IsResource(Value value) {
-  return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
+  return mlir::isa<TF::ResourceType>(getElementTypeOrSelf(value.getType()));
 }

 // Get the type of the data contained in a resource. Returns null if there is
 // no single type in the resource.
 Type GetResourceSubtype(Value value) {
   auto resource_type =
-      getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>();
+      mlir::dyn_cast<TF::ResourceType>(getElementTypeOrSelf(value.getType()));
   auto subtypes = resource_type.getSubtypes();
   if (subtypes.size() == 1) return subtypes[0];
   return nullptr;
@@ -691,7 +691,7 @@ void RemoveUnusedResourceArgumentsAndForwardedRetvals(
   int64_t skipped_retvals = 0;
   for (auto entry : llvm::enumerate(old_return_vals)) {
     auto return_val = entry.value();
-    if (auto arg = return_val.dyn_cast<BlockArgument>()) {
+    if (auto arg = mlir::dyn_cast<BlockArgument>(return_val)) {
       auto it = infos.find(arg.getArgNumber());
       if (it != infos.end() && !it->getSecond().used) {
         return_op->eraseOperand(entry.index() - skipped_retvals++);
@@ -747,7 +747,7 @@ LogicalResult LiftArgRetResourcesForFunction(
   // with type replaced.
   llvm::SmallVector<Value, 4> skipped_args;
   for (auto& it : hoister.GetResources()) {
-    BlockArgument arg = it.first.dyn_cast<BlockArgument>();
+    BlockArgument arg = mlir::dyn_cast<BlockArgument>(it.first);
     assert(arg && "Expect resources for FuncOp to be its arguments");
     auto type_iter = resource_data_types.find(arg.getArgNumber());
     if (type_iter == resource_data_types.end()) {
@@ -772,7 +772,7 @@ LogicalResult LiftArgRetResourcesForFunction(
       Value resource = assign_variable_op.getResource();
       if (!hoister.Contains(resource)) continue;

-      auto arg = resource.dyn_cast<BlockArgument>();
+      auto arg = mlir::dyn_cast<BlockArgument>(resource);
       handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.getValue());
       assign_variable_op.erase();
@@ -1018,11 +1018,11 @@ LogicalResult HandlePartitionedCallOpCallee(
   for (auto entry :
        llvm::enumerate(callee.front().getTerminator()->getOperands())) {
     auto retval = entry.value();
-    if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
+    if (!mlir::isa<TF::ResourceType>(getElementTypeOrSelf(retval.getType()))) {
       result->old_to_new_output_indices.push_back(non_resource_results++);
       continue;
     }
-    auto aliasing_arg = retval.dyn_cast<BlockArgument>();
+    auto aliasing_arg = mlir::dyn_cast<BlockArgument>(retval);
     if (!aliasing_arg) {
       return callee.emitOpError("unsupported function call: ")
              << "resource return value does not alias an input.";
@@ -1063,7 +1063,7 @@ LogicalResult HandlePartitionedCallOpCallee(
   llvm::SmallVector<int64_t, 4> retval_indices_to_preserve;
   for (auto& val : callee.front().getTerminator()->getOpOperands()) {
     // Store indices of results that are not resources.
-    if (!getElementTypeOrSelf(val.get().getType()).isa<TF::ResourceType>())
+    if (!mlir::isa<TF::ResourceType>(getElementTypeOrSelf(val.get().getType())))
       retval_indices_to_preserve.push_back(val.getOperandNumber());
   }
   int64_t num_retvals = retval_indices_to_preserve.size();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
index 2f1c675b305516..303e5aa2b6ddeb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@@ -32,7 +33,7 @@ namespace mlir {
 namespace {

 bool IsResource(Value value) {
-  return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
+  return mlir::isa<TF::ResourceType>(getElementTypeOrSelf(value.getType()));
 }

 // Checks if a cast op is casting a resource -> resource.
@@ -182,7 +183,7 @@ void EliminateUnusedResultsForIfCase(Operation *op,
     if (cloned == func) continue;
     // Patch up the op attribute to point to the new function.
     for (NamedAttribute attr : op->getAttrs()) {
-      auto symref = attr.getValue().dyn_cast<FlatSymbolRefAttr>();
+      auto symref = mlir::dyn_cast<FlatSymbolRefAttr>(attr.getValue());
       if (!symref) continue;
       if (symref.getValue() != func.getName()) continue;
       op->setAttr(attr.getName(),
@@ -301,7 +302,8 @@ LogicalResult ForwardCommonArgToOutput(Operation *op,
   std::optional<int64_t> common_arg_index;
   for (func::FuncOp func : branches) {
     auto ret = func.front().getTerminator();
-    auto block_arg = ret->getOperand(result_idx).dyn_cast<BlockArgument>();
+    auto block_arg =
+        mlir::dyn_cast<BlockArgument>(ret->getOperand(result_idx));
     if (!block_arg) {
       return op->emitOpError("result #")
              << result_idx << " not tied to function argument for branch @"
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc
index 2ff6c78896fff2..faedd25114807e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc
@@ -37,12 +37,13 @@ struct RewriteTPUEmbeddingOps

 // Rewrites the given op to `OpT` op after adding the given operand at the end.
 template <typename OpT>
-OpT AddOperandAndRewriteAs(Operation* op, Value operand, OpBuilder* builder) {
+OpT AddOperandAndRewriteAs(Operation* op, Value operand, NamedAttrList attr,
+                           OpBuilder* builder) {
   builder->setInsertionPoint(op);
   auto operands = llvm::to_vector<4>(op->getOperands());
   operands.push_back(operand);
   auto new_op = builder->create<OpT>(op->getLoc(), op->getResultTypes(),
-                                     operands, op->getAttrs());
+                                     operands, attr.getAttrs());
   op->replaceAllUsesWith(new_op.getOperation()->getResults());
   op->erase();
   return new_op;
@@ -83,8 +84,8 @@ LogicalResult RunOnRegion(Region* region) {
   // Rewrite RecvTPUEmbeddingActivations op to the corresponding internal op.
   if (recv_op)
-    AddOperandAndRewriteAs<XlaRecvTPUEmbeddingActivationsOp>(recv_op, dedup_op,
-                                                             &builder);
+    AddOperandAndRewriteAs<XlaRecvTPUEmbeddingActivationsOp>(
+        recv_op, dedup_op, recv_op->getAttrs(), &builder);

   // Rewrite SendTPUEmbeddingGradients op to the corresponding internal op and
   // then update the OperandSegmentSize attribute.
@@ -92,11 +93,11 @@ LogicalResult RunOnRegion(Region* region) {
     int32_t operand_sizes[] = {static_cast<int32_t>(send_op.getN()),
                                static_cast<int32_t>(send_op.getNN()), 1};
     auto operand_size_attr = builder.getDenseI32ArrayAttr(operand_sizes);
+    NamedAttrList attrs(send_op->getAttrs());
+    attrs.set(send_op.getOperandSegmentSizeAttr(), operand_size_attr);

-    auto new_send_op = AddOperandAndRewriteAs<XlaSendTPUEmbeddingGradientsOp>(
-        send_op, dedup_op, &builder);
-    new_send_op->setAttr(new_send_op.getOperandSegmentSizeAttr(),
-                         operand_size_attr);
+    AddOperandAndRewriteAs<XlaSendTPUEmbeddingGradientsOp>(send_op, dedup_op,
+                                                           attrs, &builder);
   }
   return success();
 }
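The rewrite_tpu_embedding_ops.cc change above threads a NamedAttrList into AddOperandAndRewriteAs so that attributes such as the operand segment sizes are fixed up before the replacement op is built, instead of being patched onto the new op afterwards. A minimal sketch of that attribute-override idiom (the attribute name and segment sizes below are placeholders; real code should take the name from getOperandSegmentSizeAttr(), as the patch does):

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/Operation.h"
    #include "mlir/IR/OperationSupport.h"  // mlir::NamedAttrList

    // Copy an op's attributes and override one entry up front, so the
    // replacement op can be created with the final attribute set.
    mlir::NamedAttrList WithSegmentSizes(mlir::Operation* op,
                                         mlir::OpBuilder& builder) {
      mlir::NamedAttrList attrs(op->getAttrs());
      attrs.set("operand_segment_sizes",                     // placeholder name
                builder.getDenseI32ArrayAttr({2, 1, 1}));    // illustrative sizes
      return attrs;
    }

Setting attributes before creation avoids a window in which the new op exists with an inconsistent operand/segment mapping.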
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h
index c86c4383cc602f..4dd6ae7c8e4a7d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h
@@ -18,6 +18,7 @@ limitations under the License.

 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project

 namespace mlir {
 namespace TF {
@@ -27,13 +28,13 @@ namespace TF {
 template <typename T>
 DenseElementsAttr GetScalarOfType(Type ty, T raw_value) {
   RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
-  if (auto float_ty = ty.dyn_cast<FloatType>()) {
+  if (auto float_ty = mlir::dyn_cast<FloatType>(ty)) {
     FloatAttr attr = FloatAttr::get(float_ty, raw_value);
     return DenseElementsAttr::get(scalar_ty, attr);
-  } else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
+  } else if (auto int_ty = mlir::dyn_cast<IntegerType>(ty)) {
     IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
     return DenseElementsAttr::get(scalar_ty, attr);
-  } else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
+  } else if (auto complex_ty = mlir::dyn_cast<ComplexType>(ty)) {
     Type complex_element_ty = complex_ty.getElementType();
     if (complex_element_ty.isF32()) {
       return DenseElementsAttr::get(
@@ -50,13 +51,13 @@ DenseElementsAttr GetScalarOfType(Type ty, T raw_value) {
 // to `raw_value`.
 template <typename T>
 bool IsConstantValueOf(Value value, T raw_value) {
-  auto element_type = value.getType().cast<ShapedType>().getElementType();
-  if (element_type.isa<FloatType>()) {
+  auto element_type = mlir::cast<ShapedType>(value.getType()).getElementType();
+  if (mlir::isa<FloatType>(element_type)) {
     DenseFPElementsAttr float_attr;
     if (matchPattern(value, m_Constant(&float_attr)) && float_attr.isSplat() &&
         float_attr.getSplatValue<APFloat>().isExactlyValue(raw_value))
       return true;
-  } else if (element_type.isa<IntegerType>()) {
+  } else if (mlir::isa<IntegerType>(element_type)) {
     DenseIntElementsAttr int_attr;
     if (matchPattern(value, m_Constant(&int_attr)) && int_attr.isSplat() &&
         int_attr.getSplatValue<APInt>() == raw_value)
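GetScalarOfType builds a rank-0 splat DenseElementsAttr for float, integer, and complex element types. A hedged usage sketch for the float case only (the function name is illustrative, not from this patch):

    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/BuiltinTypes.h"

    // Rank-0 splat with value 1.0, roughly what GetScalarOfType(ty, 1)
    // returns when `ty` is a FloatType.
    mlir::DenseElementsAttr MakeScalarOne(mlir::FloatType float_ty) {
      auto scalar_ty = mlir::RankedTensorType::get({}, float_ty);
      return mlir::DenseElementsAttr::get(scalar_ty,
                                          mlir::FloatAttr::get(float_ty, 1.0));
    }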
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -64,27 +65,27 @@ FailureOr GetTPUInfeedLayout(const ArrayRef types, llvm::SmallVector v; v.reserve(types.size()); for (const mlir::Type &t : types) { - if (t.isa()) continue; + if (mlir::isa(t)) continue; auto layout = GetTPUInfeedLayout({t}, rewriter); if (failed(layout)) return failure(); v.push_back(layout.value()); } ArrayRef shape(v); return rewriter.getArrayAttr(shape); - } else if (types[0].isa()) { - auto tuple_type = types[0].dyn_cast(); + } else if (mlir::isa(types[0])) { + auto tuple_type = mlir::dyn_cast(types[0]); const auto &types = tuple_type.getTypes(); llvm::SmallVector v; v.reserve(types.size()); for (const mlir::Type &t : types) { - if (t.isa()) continue; + if (mlir::isa(t)) continue; auto layout = GetTPUInfeedLayout({t}, rewriter); if (failed(layout)) return failure(); v.push_back(layout.value()); } ArrayRef shape(v); return rewriter.getArrayAttr(shape); - } else if (auto t = types[0].dyn_cast()) { + } else if (auto t = mlir::dyn_cast(types[0])) { if (!t.hasStaticShape()) return failure(); auto layout = GetTPUInfeedLayoutFromAPI(t); std::vector minor_to_major; @@ -129,7 +130,7 @@ bool SetTPUInfeedLayout(mlir::OwningOpRef &mlir_module) { std::vector result_types; for (mlir::Type t : op.getResultTypes()) { - auto ty = t.cast(); + auto ty = mlir::cast(t); if (!ty.hasStaticShape()) return mlir::WalkResult::interrupt(); result_types.push_back(t); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index dc1cfe3f5920fe..6a9527aea26b3f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -29,6 +29,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" @@ -75,6 +77,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" @@ -89,6 +92,7 @@ limitations under the License. 
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/ir/types/dialect.h" +#include "tsl/platform/errors.h" #define DEBUG_TYPE "tf-shape-inference" @@ -121,9 +125,9 @@ Type TypeMeet(Type lhs, Type rhs) { DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs); if (lhs == rhs) return lhs; - auto rhs_shape_type = rhs.dyn_cast(); + auto rhs_shape_type = mlir::dyn_cast(rhs); if (!rhs_shape_type) return lhs; - auto lhs_shape_type = lhs.cast(); + auto lhs_shape_type = mlir::cast(lhs); if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() && lhs_shape_type.getRank() != rhs_shape_type.getRank()) { DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs); @@ -163,7 +167,8 @@ Type TypeMeet(Type lhs, Type rhs) { // returned type. auto lhs_element_type = lhs_shape_type.getElementType(); auto rhs_element_type_with_subtype = - rhs_shape_type.getElementType().dyn_cast(); + mlir::dyn_cast( + rhs_shape_type.getElementType()); // Look for resource or variant element type and ensure we refine the subtype. // We only support a single subtype at the moment, we won't handle something // like: @@ -171,7 +176,7 @@ Type TypeMeet(Type lhs, Type rhs) { if (rhs_element_type_with_subtype && rhs_element_type_with_subtype.GetSubtypes().size() == 1) { auto lhs_element_type_with_subtype = - lhs_element_type.dyn_cast(); + mlir::dyn_cast(lhs_element_type); TensorType subtype; if (!lhs_element_type_with_subtype) { DCOMMENT( @@ -189,10 +194,9 @@ Type TypeMeet(Type lhs, Type rhs) { // and: // tensor>> // we'll try here to refine tensor with tensor<10x8xf32>. - auto refined_subtype = + auto refined_subtype = mlir::cast( TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(), - rhs_element_type_with_subtype.GetSubtypes().front()) - .cast(); + rhs_element_type_with_subtype.GetSubtypes().front())); if (refined_subtype != lhs_element_type_with_subtype.GetSubtypes().front()) subtype = refined_subtype; @@ -268,7 +272,7 @@ Value GetElementShapeOperand(Operation* op) { // Utility function to create a ranked tensor type after dropping the first // dimension from the input type. 
 RankedTensorType DropFirstDimension(Type type) {
-  RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>();
+  RankedTensorType ranked_type = mlir::dyn_cast<RankedTensorType>(type);
   if (!ranked_type) return {};
   llvm::ArrayRef<int64_t> dims_except_first =
       ranked_type.getShape().drop_front();
@@ -278,7 +282,7 @@ RankedTensorType DropFirstDimension(Type type) {
 Operation* InsertCast(OpBuilder& b, Location loc, Type dst_type, Value input) {
   Type element_type = getElementTypeOrSelf(dst_type);
-  if (element_type.isa<IndexType>())
+  if (mlir::isa<IndexType>(element_type))
     return b.create<tensor::CastOp>(loc, dst_type, input);
   if (isa<TensorFlowDialect>(element_type.getDialect()))
     return b.create<TF::CastOp>(loc, dst_type, input,
@@ -338,7 +342,7 @@ bool CanInferTensorListElementType(Value tensorlist,
   for (auto& use : tensorlist.getUses()) {
     if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
       auto element_type =
-          push.getTensor().getType().dyn_cast<RankedTensorType>();
+          mlir::dyn_cast<RankedTensorType>(push.getTensor().getType());
       if (!verify_and_update_potential_element_type(element_type))
         return false;
       add_to_worklist(push.getOutputHandle());
@@ -357,7 +361,7 @@ bool CanInferTensorListElementType(Value tensorlist,
     }
     if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
       auto element_type =
-          set_item.getItem().getType().dyn_cast<RankedTensorType>();
+          mlir::dyn_cast<RankedTensorType>(set_item.getItem().getType());
       DCOMMENT("\tTensorListSetItemOp " << element_type);
       if (!verify_and_update_potential_element_type(element_type))
         return false;
@@ -429,8 +433,8 @@ bool CanInferTensorListElementType(Value tensorlist,
 // Returns the tensor type created from the `shape_attr` and `type_attr`
 // attributes.
 Type GetType(Attribute shape_attr, Attribute type_attr) {
-  auto shape = shape_attr.cast<tf_type::ShapeAttr>();
-  auto type = type_attr.cast<TypeAttr>();
+  auto shape = mlir::cast<tf_type::ShapeAttr>(shape_attr);
+  auto type = mlir::cast<TypeAttr>(type_attr);
   if (shape.hasRank())
     return tensorflow::GetTypeFromTFTensorShape(shape.getShape(),
                                                 type.getValue());
@@ -441,7 +445,7 @@ Type GetType(Attribute shape_attr, Attribute type_attr) {

 // Returns whether type can be further refined.
 bool CanBeRefined(Type type) {
-  auto shape_type = type.dyn_cast<ShapedType>();
+  auto shape_type = mlir::dyn_cast<ShapedType>(type);
   if (!shape_type) return false;

   // Returns whether type with subtypes can be further refined.
@@ -449,8 +453,8 @@ bool CanBeRefined(Type type) {
     return tws.GetSubtypes().empty() ||
            llvm::any_of(tws.GetSubtypes(), CanBeRefined);
   };
-  auto type_with_subtype =
-      shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
+  auto type_with_subtype = mlir::dyn_cast<TF::TensorFlowTypeWithSubtype>(
+      shape_type.getElementType());
   if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true;

   return !shape_type.hasStaticShape();
@@ -463,7 +467,7 @@ Type GetNewArgType(Type old_arg_type, ArrayRef<int64_t> shape,
                    Type element_type, mlir::MLIRContext* context) {
   Type new_arg_type = tensorflow::GetTypeFromTFTensorShape(shape, element_type);

-  if (auto input_ty = old_arg_type.dyn_cast<RankedTensorType>()) {
+  if (auto input_ty = mlir::dyn_cast<RankedTensorType>(old_arg_type)) {
     ArrayRef<int64_t> bounds = hlo::encodingToBounds(input_ty.getEncoding());
     // The input type has bounded dynamic dimension.
     if (!bounds.empty()) {
@@ -501,12 +505,12 @@ struct ValuePort {

   // Convert output value to ValuePort.
   explicit ValuePort(Value v) {
-    OpResult opr = v.dyn_cast<OpResult>();
+    OpResult opr = mlir::dyn_cast<OpResult>(v);
     if (opr) {
       producer = opr.getOwner();
       port = {opr.getResultNumber()};
     } else {
-      producer = v.cast<BlockArgument>();
+      producer = mlir::cast<BlockArgument>(v);
       port = {0};
     }
   }
@@ -545,7 +549,7 @@ using ValuePortInputs = SmallVectorImpl<ValuePort>;

 // Maps the specified component in the `port` of the given op's result to one of
 // the element in the input.
 ValuePort ComputeInputComponentFor(PackOp op, ArrayRef<unsigned int> port) {
-  auto type = op.getType().cast<TensorType>();
+  auto type = mlir::cast<TensorType>(op.getType());
   if (!type.hasRank() || type.getRank() != 1) return {};
   if (port.size() != 2) return {};
   assert(port[0] == 0);
@@ -558,7 +562,7 @@ ValuePort ComputeInputComponentFor(ConcatV2Op op, ArrayRef<unsigned int> port) {
   int64_t element_idx = port[1];
   for (Value val : op.getValues()) {
-    auto val_ty = val.getType().cast<TensorType>();
+    auto val_ty = mlir::cast<TensorType>(val.getType());
     if (!val_ty.hasStaticShape() || val_ty.getRank() != 1) return {};

     int64_t dim_size = val_ty.getNumElements();
@@ -579,7 +583,7 @@ ValuePort ComputeInputComponentFor(GatherV2Op op, ArrayRef<unsigned int> port) {
   assert(port[0] == 0);

   auto params = op.getParams();
-  auto params_ty = params.getType().dyn_cast<RankedTensorType>();
+  auto params_ty = mlir::dyn_cast<RankedTensorType>(params.getType());
   if (!params_ty || !params_ty.hasStaticShape() || params_ty.getRank() != 1 ||
       op.getBatchDims() != 0) {
     return {};
@@ -683,7 +687,7 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
   if (auto shape_op = dyn_cast<ShapeOp>(op)) {
     // No shape available in an unranked tensor type.
     auto operand_ty =
-        shape_op.getOperand().getType().dyn_cast<RankedTensorType>();
+        mlir::dyn_cast<RankedTensorType>(shape_op.getOperand().getType());
     if (!operand_ty) return nullptr;

     // Shape op has a single output so the first element should always be zero
@@ -1130,14 +1134,14 @@ bool ShapeInference::InferShapeForCast(Operation* op) {
   if (!new_type) {
     // Combine shape information when leaf element types are not the same, not
     // including shape info in subtypes.
-    auto ranked_operand_type = operand_type.dyn_cast<RankedTensorType>();
+    auto ranked_operand_type = mlir::dyn_cast<RankedTensorType>(operand_type);
    if (!ranked_operand_type) return false;
-    auto ranked_res_type = result.getType().dyn_cast<RankedTensorType>();
+    auto ranked_res_type = mlir::dyn_cast<RankedTensorType>(result.getType());
     if (ranked_res_type &&
         ranked_operand_type.getShape() == ranked_res_type.getShape())
       return false;
-    auto shaped_res_type = result_type.dyn_cast<ShapedType>();
+    auto shaped_res_type = mlir::dyn_cast<ShapedType>(result_type);
     if (!shaped_res_type) return false;
     new_type = tensorflow::GetTypeFromTFTensorShape(
         ranked_operand_type.getShape(), shaped_res_type.getElementType());
@@ -1292,7 +1296,7 @@ bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) {
   int next_op_result = 0;
   for (auto output_type : main_output_types) {
     if (tensorflow::IsTokenType(output_type)) continue;
-    auto output_type_ranked = output_type.dyn_cast<RankedTensorType>();
+    auto output_type_ranked = mlir::dyn_cast<RankedTensorType>(output_type);
     if (output_type_ranked == nullptr) {
       llvm::errs() << "Unsupported XlaCallModule result type: " << output_type
                    << "\n";
@@ -1418,20 +1422,20 @@ bool ShapeInference::InferShapeForRestore(Operation* op) {
     if (!assign_op) {
       continue;
     }
-    auto subtypes = getElementTypeOrSelf(assign_op.getResource())
-                        .cast<TF::ResourceType>()
+    auto subtypes = mlir::cast<TF::ResourceType>(
+                        getElementTypeOrSelf(assign_op.getResource()))
                         .getSubtypes();
     if (subtypes.empty()) {
       continue;
     }
-    auto subtype = subtypes.front().dyn_cast<ShapedType>();
+    auto subtype = mlir::dyn_cast<ShapedType>(subtypes.front());
     if (subtype == nullptr) {
       continue;
     }
     // Preserve the dtype from the restore op even if `AssignVariableOp` uses a
     // different dtype, which is possible when there's a `CastOp` between them.
     subtype = subtype.clone(
-        op->getResult(0).getType().cast<ShapedType>().getElementType());
+        mlir::cast<ShapedType>(op->getResult(0).getType()).getElementType());
     // Update the result type of this op with the resource's type. We only use
     // the resource subtype of the first user since shapes from all the users
     // should be equal or compatible.
@@ -1456,7 +1460,7 @@ DatasetInput GetDatasetInput(Value value) {
   while (
       llvm::isa_and_nonnull<IdentityOp, IdentityNOp>(value.getDefiningOp())) {
     value = value.getDefiningOp()->getOperand(
-        value.cast<OpResult>().getResultNumber());
+        mlir::cast<OpResult>(value).getResultNumber());
   }

   Operation* op = value.getDefiningOp();
@@ -1664,14 +1668,14 @@ bool ShapeInference::InferShapeForTensorListPopBackOp(TensorListPopBackOp op) {
   DCOMMENT_OP(op, "Inferring shape for TensorListPopBackOp.");

   auto src_list_handle_t =
-      op.getOperand(0).getType().dyn_cast_or_null<TensorType>();
+      mlir::dyn_cast_or_null<TensorType>(op.getOperand(0).getType());
   if (!src_list_handle_t) return false;

   // Copy of operand tensorlist type.
   TensorType dst_list_handle_t =
       src_list_handle_t.clone(src_list_handle_t.getElementType());
   auto variant_element_t =
-      dst_list_handle_t.getElementType().dyn_cast_or_null<VariantType>();
+      mlir::dyn_cast_or_null<VariantType>(dst_list_handle_t.getElementType());
   if (!variant_element_t || variant_element_t.getSubtypes().size() != 1)
     return false;
@@ -1722,7 +1726,7 @@ bool ShapeInference::InferShapeForVarHandleOp(VarHandleOp op) {
       llvm_unreachable("unexpected operator type");
   }

-  TensorType resource_subtype = value.getType().cast<TensorType>();
+  TensorType resource_subtype = mlir::cast<TensorType>(value.getType());
   ResourceType resource_type =
       ResourceType::get({resource_subtype}, op.getContext());
   UnrankedTensorType new_resource_type =
@@ -1854,7 +1858,7 @@ bool ShapeInference::InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op) {

   bool changed = false;

-  auto input_ty = op.getInput().getType().cast<ShapedType>();
+  auto input_ty = mlir::cast<ShapedType>(op.getInput().getType());
   DenseElementsAttr window_dimensions, window_strides, base_dilations,
       window_dilations, padding;
   if (input_ty.hasStaticShape() &&
@@ -1901,7 +1905,7 @@ bool ShapeInference::InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op) {
     }
     auto output_shape = InferWindowOutputShape(
         input_ty, window.value(),
-        op.getInitValue().getType().cast<ShapedType>().getElementType());
+        mlir::cast<ShapedType>(op.getInitValue().getType()).getElementType());

     if (!output_shape) {
       op->emitOpError("failed to infer output shape");
@@ -1918,8 +1922,8 @@ bool ShapeInference::InferShapeForXlaSelectAndScatterOp(
     XlaSelectAndScatterOp op) {
   DCOMMENT_OP(op, "Inferring shape for XlaSelectAndScatterOp");

-  auto operand_shape = op.getOperand().getType().cast<ShapedType>();
-  auto source_shape = op.getSource().getType().cast<ShapedType>();
+  auto operand_shape = mlir::cast<ShapedType>(op.getOperand().getType());
+  auto source_shape = mlir::cast<ShapedType>(op.getSource().getType());
   DenseElementsAttr window_dimensions, window_strides, padding;
   if (operand_shape.hasRank() && source_shape.hasRank() &&
       matchPattern(op.getWindowDimensions(), m_Constant(&window_dimensions)) &&
@@ -2081,13 +2085,14 @@ LogicalResult PrecheckForXlaConvV2Op(XlaConvV2Op op) {
   int64_t batch_group_count = op.getBatchGroupCount();

   auto input_args_have_static_shape = [&]() -> bool {
-    return input_tensor.getType().cast<TensorType>().hasStaticShape() &&
-           kernel_tensor.getType().cast<TensorType>().hasStaticShape() &&
-           window_strides.getType().cast<TensorType>().hasStaticShape() &&
-           padding.getType().cast<TensorType>().hasStaticShape() &&
-           lhs_dilation.getType().cast<TensorType>().hasStaticShape() &&
-           rhs_dilation.getType().cast<TensorType>().hasStaticShape() &&
-           feature_group_count.getType().cast<TensorType>().hasStaticShape();
+    return mlir::cast<TensorType>(input_tensor.getType()).hasStaticShape() &&
+           mlir::cast<TensorType>(kernel_tensor.getType()).hasStaticShape() &&
+           mlir::cast<TensorType>(window_strides.getType()).hasStaticShape() &&
+           mlir::cast<TensorType>(padding.getType()).hasStaticShape() &&
+           mlir::cast<TensorType>(lhs_dilation.getType()).hasStaticShape() &&
+           mlir::cast<TensorType>(rhs_dilation.getType()).hasStaticShape() &&
+           mlir::cast<TensorType>(feature_group_count.getType())
+               .hasStaticShape();
   };

   // Return failure when one of the input args has not a static shape
@@ -2096,9 +2101,9 @@ LogicalResult PrecheckForXlaConvV2Op(XlaConvV2Op op) {
   }

   auto input_tensor_shape =
-      input_tensor.getType().cast<RankedTensorType>().getShape();
+      mlir::cast<RankedTensorType>(input_tensor.getType()).getShape();
   auto kernel_tensor_shape =
-      kernel_tensor.getType().cast<RankedTensorType>().getShape();
+      mlir::cast<RankedTensorType>(kernel_tensor.getType()).getShape();

   if (input_tensor_shape.size() <= 2) {
     return op.emitOpError()
@@ -2225,14 +2230,16 @@ bool ShapeInference::InferShapeForXlaConvV2Op(XlaConvV2Op op) {
     xla::ConvolutionDimensionNumbers dnums;
     dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str());

-    auto input_tensor_shape = input_tensor.getType().cast<RankedTensorType>();
+    auto input_tensor_shape =
+        mlir::cast<RankedTensorType>(input_tensor.getType());
     for (auto i = 0; i < input_tensor_shape.getShape().size(); ++i) {
       DCOMMENT("Input Tensor Shape " << i << "th is "
                                      << input_tensor_shape.getShape()[i]);
       input_tensor_dims_vec.push_back(input_tensor_shape.getShape()[i]);
     }

-    auto kernel_tensor_shape = kernel_tensor.getType().cast<RankedTensorType>();
+    auto kernel_tensor_shape =
+        mlir::cast<RankedTensorType>(kernel_tensor.getType());
     for (auto i = 0; i < kernel_tensor_shape.getShape().size(); ++i) {
       DCOMMENT("Kernel tensor Shape" << i << "th is "
                                      << kernel_tensor_shape.getShape()[i]);
@@ -2315,7 +2322,7 @@ bool ShapeInference::RefineWithInferTypeOpInterface(
 ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
                                                  InferenceContext* ic) {
   LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
-  auto rt = result.getType().dyn_cast<RankedTensorType>();
+  auto rt = mlir::dyn_cast<RankedTensorType>(result.getType());
   if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
   int dim_size = rt.getDimSize(0);
@@ -2362,7 +2369,7 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
     // If worklist is empty, then this is the root query op.
     if (worklist.empty()) {
       LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
-      if (auto dea = ret.dyn_cast<DenseElementsAttr>()) {
+      if (auto dea = mlir::dyn_cast<DenseElementsAttr>(ret)) {
         if (dea.getNumElements() != 1) {
           LLVM_DEBUG(llvm::dbgs() << "Unexpected number of elements\n");
           return {};
@@ -2400,7 +2407,7 @@ bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op,
   for (auto entry : llvm::zip(operands, results)) {
     Type operand_type = std::get<0>(entry).getType();
     Value result = std::get<1>(entry);
-    TensorType result_type = result.getType().cast<TensorType>();
+    TensorType result_type = mlir::cast<TensorType>(result.getType());
     Type inferred_type = TypeMeet(result_type, operand_type);
     if (result_type == inferred_type) continue;
@@ -2466,10 +2473,10 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
 Type GetElementTypeFromOperand(TensorType operand_type,
                                TensorType result_type) {
   auto operand_handle_type =
-      operand_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
+      mlir::dyn_cast<TF::TensorFlowTypeWithSubtype>(
+          operand_type.getElementType());
   if (!operand_handle_type) return result_type.getElementType();
   auto result_handle_type =
-      result_type.getElementType().cast<TF::TensorFlowTypeWithSubtype>();
+      mlir::cast<TF::TensorFlowTypeWithSubtype>(result_type.getElementType());
   if (operand_handle_type.GetSubtypes().empty() ||
       !result_handle_type.GetSubtypes().empty())
     return result_type.getElementType();
@@ -2505,9 +2512,8 @@ bool ShapeInference::InferShapeForWhile(WhileOpTy op,
   for (auto entry :
        zip(op.getInput().getTypes(), op.getOutput(), body_result_types)) {
     Value result = std::get<1>(entry);
-    TensorType body_result_type =
-        std::get<2>(entry).template cast<TensorType>();
-    auto result_type = result.getType().cast<TensorType>();
+    TensorType body_result_type = mlir::cast<TensorType>(std::get<2>(entry));
+    auto result_type = mlir::cast<TensorType>(result.getType());

     Type potential_refined_type;
     if (CanWhileTypeBeRefinedWith(result_type, body_result_type)) {
@@ -2518,7 +2524,7 @@ bool ShapeInference::InferShapeForWhile(WhileOpTy op,
               : std::optional<ArrayRef<int64_t>>(),
           element_type);
     } else {
-      TensorType operand_type = std::get<0>(entry).template cast<TensorType>();
+      TensorType operand_type = mlir::cast<TensorType>(std::get<0>(entry));
       Type element_type = GetElementTypeFromOperand(operand_type, result_type);
       potential_refined_type = CreateTensorType(
           result_type.hasRank() ? result_type.getShape()
                                 : std::optional<ArrayRef<int64_t>>(),
@@ -2671,7 +2677,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op,
   // Return result element type at `index`.
   auto result_element_type_fn = [&](int index) {
-    return op->getResult(index).getType().cast<TensorType>().getElementType();
+    return mlir::cast<TensorType>(op->getResult(index).getType())
+        .getElementType();
   };

   llvm::SmallVector<ShapedTypeComponents, 4> inferred_return_shapes;
@@ -2698,7 +2705,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op,
       inferred_type = UnrankedTensorType::get(inferred.getElementType());
     }
     inferred_type =
-        TypeMeet(op_result.getType(), inferred_type).cast<TensorType>();
+        mlir::cast<TensorType>(TypeMeet(op_result.getType(), inferred_type));
     if (op_result.getType() == inferred_type) continue;
     if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, op_result))
       continue;
@@ -2879,19 +2886,19 @@ llvm::SmallVector<Type, 4> GetWhileCompatibleTypes(
   types.reserve(operand_types.size());
   for (auto entry :
        llvm::zip(operand_types, result_types, region_argument_types)) {
-    auto operand_type = std::get<0>(entry).cast<TensorType>();
-    auto result_type = std::get<1>(entry).cast<TensorType>();
+    auto operand_type = mlir::cast<TensorType>(std::get<0>(entry));
+    auto result_type = mlir::cast<TensorType>(std::get<1>(entry));
     if (operand_type == result_type) {
       types.push_back(operand_type);
     } else if (RankedAndSameRank(operand_type, result_type)) {
-      auto potential_refined_type =
-          GetCompatibleRankedTensorType(operand_type.cast<RankedTensorType>(),
-                                        result_type.cast<RankedTensorType>());
+      auto potential_refined_type = GetCompatibleRankedTensorType(
+          mlir::cast<RankedTensorType>(operand_type),
+          mlir::cast<RankedTensorType>(result_type));
       types.push_back(potential_refined_type);
     } else {
-      auto region_argument_type = std::get<2>(entry).cast<TensorType>();
+      auto region_argument_type = mlir::cast<TensorType>(std::get<2>(entry));
       Type element_type = GetElementTypeFromOperand(
-          operand_type.cast<TensorType>(), region_argument_type);
+          mlir::cast<TensorType>(operand_type), region_argument_type);
       Type potential_refined_type = CreateTensorType(
           region_argument_type.hasRank()
               ? region_argument_type.getShape()
              : std::optional<ArrayRef<int64_t>>(),
@@ -3064,7 +3071,7 @@ LogicalResult ShapeInference::TryToFold(Operation* op) {
     }
   }

-  if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
+  if (ElementsAttr eattr = mlir::dyn_cast_or_null<ElementsAttr>(attr)) {
    if (std::get<0>(result).getType() == eattr.getType()) continue;

     (void)UpdateTypeAndInsertIncompatibleUseCasts(eattr.getType(),
@@ -3224,6 +3231,25 @@ static FailureOr<bool> InferShapeForFunction(ShapeInference& context,
   return true;
 }

+absl::StatusOr<SmallVector<SmallVector<int64_t>>> ParseArgumentShapes(
+    absl::string_view input_shapes) {
+  SmallVector<SmallVector<int64_t>> parsed_shapes;
+  if (input_shapes.empty()) {
+    return parsed_shapes;
+  }
+
+  std::vector<std::optional<std::vector<int>>> shapes;
+  TF_RETURN_IF_ERROR(::tensorflow::ParseNodeShapes(input_shapes, shapes));
+
+  for (const auto& shape : shapes) {
+    if (!shape) {
+      return absl::AbortedError("Missing input argument shapes");
+    }
+    parsed_shapes.push_back(SmallVector<int64_t>(shape->begin(), shape->end()));
+  }
+  return parsed_shapes;
+}
+
 FailureOr<bool> InferShapeForFunction(func::FuncOp func,
                                       ArrayRef<ArrayRef<int64_t>> arg_shapes,
                                       int64_t graph_version,
@@ -3245,13 +3271,15 @@ FailureOr<bool> InferShapeForFunction(func::FuncOp func,
   for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
     ArrayRef<int64_t> shape = arg_shapes[i];
     Type element_type;
-    if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
+    if (auto input_ty =
+            mlir::dyn_cast<RankedTensorType>(func_type.getInput(i))) {
       if (input_ty.getRank() != shape.size()) {
         return failure();
       }
       element_type = input_ty.getElementType();
     } else {
-      auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
+      auto unranked_input_ty =
+          mlir::dyn_cast<TensorType>(func_type.getInput(i));
       if (!unranked_input_ty) {
         return failure();
       }
@@ -3284,7 +3312,8 @@ FailureOr<bool> InferShapeForFunction(func::FuncOp func,
 }

 FailureOr<bool> InferModuleShape(ModuleOp module, int64_t max_iterations,
-                                 ArrayRef<TypeID> ops_to_skip) {
+                                 ArrayRef<TypeID> ops_to_skip,
+                                 ArrayRef<ArrayRef<int64_t>> input_shapes) {
   auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
   if (!producer_or.ok()) {
     // TODO(jpienaar): Keeping the existing behavior for now but this could
@@ -3294,13 +3323,30 @@ FailureOr<bool> InferModuleShape(ModuleOp module, int64_t max_iterations,
     return true;
   }
   int64_t producer = producer_or.value();
+
   // TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if
   // it is no longer needed.
   ShapeInference context(producer, module,
                          /*propagate_caller_callee_constants=*/false,
                          ops_to_skip);
-  if (auto main = module.lookupSymbol<mlir::func::FuncOp>("main"))
+  auto main = module.lookupSymbol<mlir::func::FuncOp>("main");
+  // Error if no main to refine with input shapes
+  if (!main && !input_shapes.empty()) {
+    return module->emitError(
+        "Input shapes provided but no `main` function found.");
+  }
+
+  // Add main function to head of queue, refine input shapes if provided
+  if (main) {
+    if (!input_shapes.empty()) {
+      FailureOr<bool> failure_or_converged =
+          InferShapeForFunction(main, input_shapes, producer,
+                                /*max_iterations=*/10, ops_to_skip);
+      if (failed(failure_or_converged) || !failure_or_converged.value())
+        return failure_or_converged;
+    }
     context.enqueue(main);
+  }
   for (auto func : module.getOps<func::FuncOp>()) context.enqueue(func);
   // Arbitrarily upper bound the maximum number of functions that get processed
   // just to avoid pathological cases.
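ParseArgumentShapes accepts TensorFlow's NodeShape syntax: `,` separates dimensions, `:` separates arguments, and `?` marks an unknown dimension. A hedged sketch of wiring it into the new InferModuleShape overload (the wrapper function and the shape string are illustrative, assuming the declarations added to shape_inference.h below):

    #include <cstdint>

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"

    // "1,224,224,3:64" describes two arguments: [1,224,224,3] and [64].
    // Refines `main` with these shapes, then runs inference to convergence.
    mlir::LogicalResult RefineFromNodeShapeString(mlir::ModuleOp module) {
      auto parsed = mlir::TF::ParseArgumentShapes("1,224,224,3:64");
      if (!parsed.ok()) return mlir::failure();
      llvm::SmallVector<llvm::ArrayRef<int64_t>> shapes(parsed->begin(),
                                                        parsed->end());
      auto converged = mlir::TF::InferModuleShape(module, /*max_iterations=*/10,
                                                  /*ops_to_skip=*/{}, shapes);
      return mlir::failure(failed(converged) || !converged.value());
    }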
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
index bc1cf7b3c8f475..46c1bc9c00e55a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
@@ -18,6 +18,8 @@ limitations under the License.

 #include <cstdint>

+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
@@ -42,8 +44,18 @@ Type GetNewArgType(Type old_arg_type, ArrayRef<int64_t> shape,
 // whose type is in ops_to_skip.
 // Returns a failure() on error, otherwise returns true to indicate that it
 // reached convergence, false otherwise.
+// If input shapes are provided, first refines the `main` function using
+// InferShapeForFunction.
 FailureOr<bool> InferModuleShape(ModuleOp module, int64_t max_iterations = 10,
-                                 ArrayRef<TypeID> ops_to_skip = {});
+                                 ArrayRef<TypeID> ops_to_skip = {},
+                                 ArrayRef<ArrayRef<int64_t>> input_shapes = {});
+
+// Given a tensorflow NodeShape string, returns a vector of argument shapes
+// that can be used with InferShapeForFunction.
+// TF NodeShape uses `,` to separate dimensions, and `:` to separate arguments.
+// Ex: 1,2:3,4,5:6,? --> [[1, 2], [3, 4, 5], [6, ?]]
+absl::StatusOr<SmallVector<SmallVector<int64_t>>> ParseArgumentShapes(
+    absl::string_view input_shapes);

 // Given a list of refined shapes matching the function arguments of func, runs
 // shape inference over the function to propagate this updated information,
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
index 37bcb46b95cc57..392b7807b0d418 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
@@ -13,15 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include <memory>
+#include <cstdint>
+#include <memory>

-#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
-#include "tensorflow/core/framework/shape_inference.h"

 namespace mlir {
 namespace TF {
@@ -36,9 +38,26 @@ namespace {
 class ShapeInference
     : public impl::TensorFlowShapeInferencePassBase<ShapeInference> {
  public:
+  ShapeInference() = default;
+  explicit ShapeInference(ArrayRef<ArrayRef<int64_t>> input_shapes)
+      : input_shapes_(input_shapes) {}
   void runOnOperation() override {
-    auto failure_or_converged =
-        InferModuleShape(getOperation(), max_iterations_, /*ops_to_skip=*/{});
+    // Parse `input_arg_shapes_` if provided (test only)
+    SmallVector<ArrayRef<int64_t>> input_shapes_vec;
+    absl::StatusOr<SmallVector<SmallVector<int64_t>>> parsed_shapes;
+    if (!input_arg_shapes_.empty()) {
+      parsed_shapes = ParseArgumentShapes(input_arg_shapes_);
+      if (!parsed_shapes.ok()) {
+        getOperation().emitError() << parsed_shapes.status().message();
+        return signalPassFailure();
+      }
+      input_shapes_vec = SmallVector<ArrayRef<int64_t>>{parsed_shapes->begin(),
+                                                        parsed_shapes->end()};
+      input_shapes_ = input_shapes_vec;
+    }
+
+    auto failure_or_converged = InferModuleShape(
+        getOperation(), max_iterations_, /*ops_to_skip=*/{}, input_shapes_);
     if (failed(failure_or_converged)) return signalPassFailure();
     if (!failure_or_converged.value()) {
       getOperation().emitError()
@@ -47,11 +66,15 @@ class ShapeInference
       return signalPassFailure();
     }
   }
+
+ private:
+  ArrayRef<ArrayRef<int64_t>> input_shapes_;
 };
 }  // namespace

-std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass() {
-  return std::make_unique<ShapeInference>();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass(
+    ArrayRef<ArrayRef<int64_t>> input_shapes) {
+  return std::make_unique<ShapeInference>(input_shapes);
 }

 }  // namespace TF
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc
index e565d50660558c..abef8ee04f2212 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc
@@ -174,8 +174,8 @@ namespace TFDevice {
 namespace {

 bool IsResourceType(Type val_type) {
-  if (auto tensor_type = val_type.dyn_cast<TensorType>()) {
-    if (tensor_type.getElementType().isa<TF::ResourceType>()) {
+  if (auto tensor_type = mlir::dyn_cast<TensorType>(val_type)) {
+    if (mlir::isa<TF::ResourceType>(tensor_type.getElementType())) {
       return true;
     }
   }
@@ -588,7 +588,7 @@ void GatherOpsForExtraction(mlir::SetVector<Operation*>* operations,
     if (predecessors) {
       for (Value operand : op->getOperands()) {
         // Stop at the block boundary.
-        if (operand.isa<BlockArgument>()) continue;
+        if (mlir::isa<BlockArgument>(operand)) continue;

         Operation* predecessor = operand.getDefiningOp();
         if (!operations->contains(predecessor) &&
@@ -1867,7 +1867,7 @@ void EmbeddingPipeliningPass::runOnOperation() {
   for (int ret_pos = 0; ret_pos < orig_return_op->getNumOperands(); ++ret_pos) {
     auto operand = orig_return_op->getOperand(ret_pos);
     auto def_op = operand.getDefiningOp();
-    auto result = operand.dyn_cast<mlir::OpResult>();
+    auto result = mlir::dyn_cast<mlir::OpResult>(operand);
     if (def_op == non_tpu_caller) {
       loop_arg_update_map_non_tpu[result.getResultNumber()] = ret_pos;
     } else if (def_op == core_tpu_caller) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc
index 3e41762feb16c2..1e7958660fd8c4 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc
@@ -314,7 +314,7 @@ void CreateReducedLaunchOp(OpBuilder* builder, Block* old_block,
   // Handle pass through block arguments.
   for (OpOperand& operand :
        original_launch_op.GetBody().getTerminator()->getOpOperands()) {
-    if (operand.get().isa<BlockArgument>()) {
+    if (mlir::isa<BlockArgument>(operand.get())) {
       original_launch_op.getResult(operand.getOperandNumber())
           .replaceAllUsesWith(operand.get());
     }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc
index 577b374a43847d..b224b723cda50d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc
@@ -95,8 +95,8 @@ std::vector<Type> GetValueTypes(const InputContainer& input) {
 }

 bool IsResourceType(Type val_type) {
-  if (auto tensor_type = val_type.dyn_cast<TensorType>()) {
-    if (tensor_type.getElementType().isa<TF::ResourceType>()) {
+  if (auto tensor_type = mlir::dyn_cast<TensorType>(val_type)) {
+    if (mlir::isa<TF::ResourceType>(tensor_type.getElementType())) {
       return true;
     }
   }
@@ -139,7 +139,7 @@ void GatherOpsForExtraction(mlir::SetVector<Operation*>* operations,
     if (predecessors) {
       for (Value operand : op->getOperands()) {
         // Stop at the block boundary.
-        if (operand.isa<BlockArgument>()) continue;
+        if (mlir::isa<BlockArgument>(operand)) continue;

         Operation* predecessor = operand.getDefiningOp();
         if (!operations->contains(predecessor) &&
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
index fb9848dbaeac47..476a67b496355f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
@@ -73,7 +73,7 @@ Type GetSizeVarType(OpBuilder builder) {
 // forwards the argument. Otherwise, returns -1.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
index fb9848dbaeac47..476a67b496355f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
@@ -73,7 +73,7 @@ Type GetSizeVarType(OpBuilder builder) {
 // forwards the argument. Otherwise, returns -1.
 int64_t FindAliasedInput(func::FuncOp func, int64_t return_index) {
   Value return_val = func.front().getTerminator()->getOperand(return_index);
-  auto maybe_arg = return_val.dyn_cast<BlockArgument>();
+  auto maybe_arg = mlir::dyn_cast<BlockArgument>(return_val);
   if (!maybe_arg) return -1;
   return maybe_arg.getArgNumber();
 }
@@ -180,8 +180,8 @@ LogicalResult HandleWhileOp(
       while_op.getLoc(), body.getFunctionType().getInputs(), new_while_operands,
       while_op->getAttrs());
   for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
-    if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
-             .isa<TF::ResourceType>()) {
+    if (!mlir::isa<TF::ResourceType>(
+            getElementTypeOrSelf(while_op.getOperand(i).getType()))) {
       continue;
     }
     int64_t aliased_input = FindAliasedInput(body, i);
@@ -233,7 +233,7 @@ LogicalResult HandleIfOp(
       if_op.getLoc(), then_func.getFunctionType().getResults(), new_if_operands,
       if_op->getAttrs());
   for (auto result : if_op.getResults()) {
-    if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
+    if (!mlir::isa<TF::ResourceType>(getElementTypeOrSelf(result.getType()))) {
       continue;
     }
     int64_t then_aliased_input =
@@ -287,8 +287,8 @@ LogicalResult HandlePartitionedCallOp(
           const_cast<func::FuncOp&>(info.decomposed_callee).getName()));
   for (int64_t i = 0; i < call.getNumResults(); ++i) {
     auto result = call.getResult(i);
-    if (!getElementTypeOrSelf(result.getType())
-             .template isa<TF::ResourceType>()) {
+    if (!mlir::isa<TF::ResourceType>(
+            getElementTypeOrSelf(result.getType()))) {
       continue;
     }
     int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i);
@@ -328,9 +328,9 @@ LogicalResult HandlePartitionedCallOp(
   } else {
     info.decomposed_callee = lowered_callee;
     for (auto& entry : callee_map) {
-      info.stack_var_arg_to_size_arg
-          [entry.getFirst().cast<BlockArgument>().getArgNumber()] =
-          entry.getSecond().cast<BlockArgument>().getArgNumber();
+      info.stack_var_arg_to_size_arg[mlir::cast<BlockArgument>(entry.getFirst())
+                                         .getArgNumber()] =
+          mlir::cast<BlockArgument>(entry.getSecond()).getArgNumber();
     }
     if (lowered_callee != callee) {
       // Add the clone with a new name.
@@ -372,7 +372,7 @@ LogicalResult HandleStackV2Op(
   auto size_var_type = GetSizeVarType(builder);
   auto var_type = RankedTensorType::get(
       {}, TF::ResourceType::get(
-              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
+              ArrayRef<TensorType>{mlir::cast<TensorType>(buffer.getType())},
               stack.getContext()));
   auto local_var = builder.create<TF::MlirLocalVarOp>(
       stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
@@ -446,7 +446,8 @@ LogicalResult HandleRegionControlFlowOps(
     llvm::StringMap<PartitionedCallStackOpsInfo>*
         decomposed_partitioned_call_callees) {
   for (OpOperand& operand : op.getOpOperands()) {
-    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
+    if (mlir::isa<TF::ResourceType>(
+            getElementTypeOrSelf(operand.get().getType()))) {
       return op.emitOpError()
              << "found unexpected type " << operand.get().getType()
              << " of operand #" << operand.getOperandNumber()
@@ -455,7 +456,7 @@ LogicalResult HandleRegionControlFlowOps(
     }
   }
   for (OpResult result : op.getResults()) {
-    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
+    if (mlir::isa<TF::ResourceType>(getElementTypeOrSelf(result.getType()))) {
       return op.emitOpError()
              << "found unexpected type " << result.getType()
              << " of result #" << result.getResultNumber()
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc
index b18a6a3496649a..267f32daa9f6e6 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" @@ -67,7 +68,7 @@ void TensorDeviceCopyConversionPass::runOnOperation() { (isa(def_op))) { return true; } - if (BlockArgument block_arg = arg.dyn_cast()) { + if (BlockArgument block_arg = mlir::dyn_cast(arg)) { // Skip the folding logic if the block argument is not from the function // arguments. This can happen when the argument is from a while loop. if (block_arg.getParentRegion() != &func_op.getRegion()) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 278ba1f7fdf65b..a9ad31a28461f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -510,10 +511,10 @@ LogicalResult HandlePartitionedCallOp( } else { info.signature_change = true; for (auto& entry : callee_map) { - auto buffer_arg = entry.getFirst().dyn_cast(); + auto buffer_arg = mlir::dyn_cast(entry.getFirst()); if (!buffer_arg) continue; info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] = - entry.getSecond().size.cast().getArgNumber(); + mlir::cast(entry.getSecond().size).getArgNumber(); } if (lowered_callee != callee) { // Add the clone with a new name. @@ -549,7 +550,8 @@ LogicalResult GetConstShapeValue(Value shape_value, // return error. 
 LogicalResult GetElementShapeFromResultType(
     Type type, llvm::SmallVector<int64_t, 8>* shape) {
-  auto variant_type = getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
+  auto variant_type =
+      mlir::dyn_cast<TF::VariantType>(getElementTypeOrSelf(type));
   if (!variant_type || variant_type.getSubtypes().size() != 1) return failure();
   TensorType tensor_type = variant_type.getSubtypes().front();
   if (!tensor_type.hasStaticShape()) return failure();
@@ -619,7 +621,7 @@ LogicalResult HandleTensorListFromTensorOp(
   Value buffer = builder.create<TF::IdentityOp>(
       list.getLoc(), ArrayRef<Type>{list.getTensor().getType()},
       ArrayRef<Value>{list.getTensor()});
-  auto type = buffer.getType().cast<RankedTensorType>();
+  auto type = mlir::cast<RankedTensorType>(buffer.getType());
   if (!type.hasStaticShape()) {
     return list.emitOpError("TensorListFromTensorOp input has unknown shape.");
   }
@@ -733,8 +735,8 @@ LogicalResult HandleTensorListLengthOp(
   OpBuilder builder(length);
   if (it->getSecond().fixed) {
     auto dim = cutil::CreateScalarConst(
-        length.getInputHandle().getType().cast<RankedTensorType>().getDimSize(
-            0),
+        mlir::cast<RankedTensorType>(length.getInputHandle().getType())
+            .getDimSize(0),
         builder, length.getLoc());
     length.getLength().replaceAllUsesWith(dim);
   } else {
@@ -760,7 +762,7 @@ LogicalResult HandleTensorListElementShapeOp(
   }
   auto buffer = elem_shape.getInputHandle();
   auto result = cutil::GetR1Const(
-      buffer.getType().cast<RankedTensorType>().getShape().drop_front(),
+      mlir::cast<RankedTensorType>(buffer.getType()).getShape().drop_front(),
       OpBuilder(elem_shape), elem_shape.getLoc(),
       elem_shape.getShapeType().getIntOrFloatBitWidth());
   elem_shape.getElementShape().replaceAllUsesWith(result);
@@ -792,7 +794,8 @@ LogicalResult HandleTensorListScatterIntoExistingListOp(
   }
   auto buffer = scatter.getInputHandle();
   OpBuilder builder(scatter);
-  auto indices_type = scatter.getIndices().getType().cast<RankedTensorType>();
+  auto indices_type =
+      mlir::cast<RankedTensorType>(scatter.getIndices().getType());
   if (!indices_type) return scatter.emitOpError("unranked indices shape");
   auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32));
   auto shape = builder.create<TF::ConstOp>(
@@ -874,7 +877,8 @@ LogicalResult DecomposeTensorListOpsInternal(
   } else if (auto addn = llvm::dyn_cast<TF::AddNOp>(&op)) {
     auto it = buffer_to_size->find(addn.getOperand(0));
     if (it != buffer_to_size->end()) {
-      addn.getSum().setType(addn.getOperand(0).getType().cast<TensorType>());
+      addn.getSum().setType(
+          mlir::cast<TensorType>(addn.getOperand(0).getType()));
       auto size = it->getSecond();
       (*buffer_to_size)[addn.getSum()] = size;
     }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
index 6b53cae7099688..dbe938e01a519b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
@@ -193,56 +193,6 @@ def StackOpsDecompositionPass : Pass<"tf-stack-ops-decomposition", "ModuleOp"> {

-def TPUShardingIdentificationPass : Pass<"tf-tpu-sharding-identification", "ModuleOp"> {
-  let summary = "Identifies and handles inputs/outputs of TPU computation that is "
-                "sharded across logical cores.";
-  let constructor = "TFTPU::CreateTPUShardingIdentificationPass()";
-  let description = [{
-    Bubbles up sharding configuration from `cluster_func` regions into
-    the attributes of `cluster_func`. This is done by parsing the
-    `XlaSharding` / `TPUPartitionedOutput` / `TPUPartitionedInput` ops inside
-    `cluster_func`.
- - For example, given the following `cluster_func` wrapping `func`: - - ```mlir - func @test(%arg0: tensor<*xi32>) { - "tf_device.cluster_func"(%arg0) { - func = @func, - step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32> - return - } - - func @func(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "\01\02\03", - sharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32> - %1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>) - return %1 : tensor<*xi32> - } - ``` - - Now, cluster_func receives the following `*_sharding_configuration` - attributes, and `func` receives the mhlo.sharding attribute: - - ```mlir - func @test(%arg0: tensor<*xi32>) { - %0 = "tf_device.cluster_func"(%arg0) { - func = @func, - input_sharding_configuration = ["\01\02\03"], - output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], - step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32> - return - } - func @func(%arg0: tensor<*xi32> {mhlo.sharding = "\01\02\03"}) -> - (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { - %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "\01\02\03", sharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32> - %1 = "tf.A"(%0) : (tensor<*xi32>) -> tensor<*xi32> - return %1 : tensor<*xi32> - } - ``` - }]; -} - def UnrollBatchMatMulPass : Pass<"tf-unroll-batch-matmul", "mlir::func::FuncOp"> { let summary = "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops."; let constructor = "TF::CreateUnrollBatchMatMulPassPass()"; @@ -381,7 +331,9 @@ def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> { let options = [ Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10", - "Maximum shape inference iterations"> + "Maximum shape inference iterations">, + Option<"input_arg_shapes_", "input-arg-shapes", "std::string", /*default=*/"", + "Input tensor shapes. Shapes for different tensors are separated by ':', and dimension sizes for the same tensor are separated by ','">, ]; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc index 0bc7b47377fa3f..40d9032b499ff6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc @@ -35,6 +35,7 @@ limitations under the License. 
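The new `input-arg-shapes` option defined above encodes one shape per argument: `,` separates the dimension sizes of a single tensor and `:` separates tensors, so `1,224,224,3:64` describes two arguments. A hedged sketch of driving the same behavior from C++ through the `CreateTFShapeInferencePass` overload this patch introduces; the concrete shapes are invented for illustration:

```cpp
#include <cstdint>
#include <memory>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

// Textual equivalent: -tf-shape-inference=input-arg-shapes=1,224,224,3:64
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> MakeShapeInferencePass() {
  // Static storage: the pass keeps only an ArrayRef view of these shapes,
  // so the backing arrays must outlive it.
  static constexpr int64_t kArg0[] = {1, 224, 224, 3};  // shape of argument 0
  static constexpr int64_t kArg1[] = {64};              // shape of argument 1
  static const llvm::ArrayRef<int64_t> kShapes[] = {kArg0, kArg1};
  return mlir::TF::CreateTFShapeInferencePass(kShapes);
}
```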
#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -70,7 +71,7 @@ class AssetSinkingPass : public impl::AssetSinkingPassBase { SymbolTable symbol_table(module); for (auto initializer : init_op.getInitializers()) { auto func = symbol_table.lookup( - initializer.cast().getValue()); + mlir::cast(initializer).getValue()); RewriteFunction(symbol_table, func); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc index 141807309c4a9c..26bc9dae51057c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc @@ -54,7 +54,7 @@ constexpr StringRef kTfInputShapesAttr = "tf._input_shapes"; // Build and returns ElementsAttr which holds the data in 'tensor'. ElementsAttr GetTensorValueAsElementsAttr(const tensorflow::Tensor& tensor, OpBuilder builder) { - tensorflow::StatusOr tensor_attr_or = + absl::StatusOr tensor_attr_or = tensorflow::ConvertTensor(tensor, &builder); if (!tensor_attr_or.ok()) return nullptr; return tensor_attr_or.value(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc index 7f449520030876..68d50e54a1bce0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc @@ -18,12 +18,13 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -58,7 +59,7 @@ static mlir::LogicalResult FilterTfgSpecificArgResultAttributes( llvm::SmallVector &output_attrs) { for (auto it : llvm::zip( types, array_attr.template getAsRange())) { - if (std::get<0>(it).isa()) continue; + if (mlir::isa(std::get<0>(it))) continue; output_types.push_back(std::get<0>(it)); mlir::NamedAttrList list; @@ -80,7 +81,7 @@ static mlir::LogicalResult ReformatOpAttributes( mlir::tfg::TFGraphDialect::getDeviceAttrKey())) { tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName( - attr.getValue().cast().getValue().str(), + mlir::cast(attr.getValue()).getValue().str(), &parsed_name)) return mlir::failure(); if (!parsed_name.has_type) { @@ -106,7 +107,7 @@ static mlir::LogicalResult ReformatOpAttributes( static void FilterOutBlockArgControlDep( ValueRange operands, llvm::SmallVectorImpl &filtered) { for (Value value : operands) - if (!value.isa()) filtered.push_back(value); + if (!mlir::isa(value)) filtered.push_back(value); } // Split the tfg.NextIteration into tf_executor::NextIterationSourceOp and @@ -114,7 +115,7 @@ static void FilterOutBlockArgControlDep( static void SplitNextIteration(Block &block) { // TODO(b/207144333): Supports callback for unregistered ops block.walk([&](Operation *op) { - if (!op->getName().getStringRef().equals("tfg.NextIteration")) return; + if (op->getName().getStringRef() != "tfg.NextIteration") return; mlir::OpBuilder builder(op); llvm::SmallVector new_operands; @@ -218,7 +219,7 @@ class ConvertGraphFuncOp : public OpConversionPattern { Block &block = graph_func.getBody().front(); for (auto iter = block.args_begin(), end_iter = block.args_end(); iter != end_iter; ++iter) { - if (!iter->getType().isa()) + if (!mlir::isa(iter->getType())) iter->replaceAllUsesWith(func.getBody().getArgument(idx++)); } @@ -412,9 +413,9 @@ class ConvertGeneralOp : public ConversionPattern { for (Value value : operands) { // Because of the property of graph region, the control operands may // not have been converted to tf_executor::ControlType. 
- if (value.getType().isa() || - value.getType().isa()) { - if (!value.isa()) + if (mlir::isa(value.getType()) || + mlir::isa(value.getType())) { + if (!mlir::isa(value)) island_control_operands.push_back(value); } else { inner_op_operands.push_back(value); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc index d1a244b7f2ec2a..b4a98605a34ac2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc @@ -53,7 +53,7 @@ class TPUAnnotateDynamicShapeInputsPass // Finds op that created a given value. If the value is a BlockArgument, this // returns the owner of the Block. Operation* GetOpOfValue(Value value) { - if (auto block_arg = value.dyn_cast()) + if (auto block_arg = mlir::dyn_cast(value)) return block_arg.getOwner()->getParentOp(); return value.getDefiningOp(); @@ -98,7 +98,7 @@ void TPUAnnotateDynamicShapeInputsPass::runOnOperation() { // Update the marked argument with dynamic shapes. for (int index : dynamic_shape_arg_index) { BlockArgument arg = func.getArgument(index); - auto inputType = arg.getType().dyn_cast(); + auto inputType = mlir::dyn_cast(arg.getType()); // Only rank 1 tensor is supported for now. if (!inputType || inputType.getRank() != 1) continue; auto shape = llvm::to_vector<4>(inputType.getShape()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc index a6f9d7d4c63f01..e2b9c62ee8e6bc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc @@ -29,6 +29,7 @@ limitations under the License. 
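One non-cast cleanup rides along in the tfg-to-tfe hunks above: `SplitNextIteration` now compares `StringRef`s with `!=` instead of the deprecated `StringRef::equals`. A tiny self-contained sketch of the same change, assuming only LLVM's ADT headers:

```cpp
#include "llvm/ADT/StringRef.h"

bool IsNextIteration(llvm::StringRef op_name) {
  // Previously: op_name.equals("tfg.NextIteration")
  return op_name == "tfg.NextIteration";
}
```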
#include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" @@ -94,8 +95,8 @@ void PopulateDeviceForOpResults( op_to_update = op_to_update->getParentOp(); for (Value result : op_to_update->getResults()) { - if (result.getType().isa()) continue; - if (result.getType().isa()) break; + if (mlir::isa(result.getType())) continue; + if (mlir::isa(result.getType())) break; value_to_device.insert({result, device}); } @@ -118,8 +119,8 @@ llvm::StringRef FindDeviceFromOperands( llvm::StringRef new_device; const bool is_switch = llvm::isa(op); for (Value operand : op.getOperands()) { - if (operand.getType().isa()) continue; - if (operand.getType().isa()) break; + if (mlir::isa(operand.getType())) continue; + if (mlir::isa(operand.getType())) break; if (is_switch && llvm::isa_and_nonnull(operand.getDefiningOp())) @@ -230,7 +231,7 @@ void PropagateDevicesToResults( mlir::Builder builder(func.getOperation()); for (OpOperand& operand : fetch.getOperation()->getOpOperands()) { - if (operand.get().getType().isa()) break; + if (mlir::isa(operand.get().getType())) break; auto it = value_to_device.find(operand.get()); if (it != value_to_device.end()) { auto device_attr = func.getResultAttrOfType( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 04b488a38048fd..2281658efc5ed1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -86,7 +86,7 @@ bool IsSupportedInputOp( resource_alias_analysis.GetResourceAliases(resource_iterator); auto is_generator = [](Value val) { - if (val.isa()) return true; + if (mlir::isa(val)) return true; Operation* definition = val.getDefiningOp(); return definition->getNumOperands() == 0 && definition->getNumResults() == 1; @@ -99,7 +99,7 @@ bool IsSupportedInputOp( if (!is_generator(alias)) return true; StringAttr device; - if (auto arg = alias.dyn_cast()) { + if (auto arg = mlir::dyn_cast(alias)) { device = func.getArgAttrOfType(arg.getArgNumber(), kFuncDeviceAttr); } else { @@ -186,10 +186,8 @@ bool HandleReplicatedInputs( BuildCopyWithLayout(execute_launch, compile_launch, get_layout, entry.value().get(), &builder); - auto device_list = replicate.getDevices() - .value() - .get(execute_launch.getDevice()) - .cast(); + auto device_list = mlir::cast( + replicate.getDevices().value().get(execute_launch.getDevice())); copy_with_layout->setAttr(kDeviceAttr, device_list.getValue()[entry.index()]); @@ -225,7 +223,7 @@ void HandleCompileAndExecutes( for (const auto& input_and_idx : llvm::enumerate(execute.getArgs())) { Value input = input_and_idx.value(); const int64_t execute_arg_index = input_and_idx.index(); - if (auto block_arg = input.dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(input)) { // For a block argument, consider transforms only when it is a // replicated input (defining ops will be outside the replicate node). 
if (maybe_replicate != block_arg.getParentRegion()->getParentOp() || diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc index fdea45957eb7d8..b2a3b81f63a1a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -41,7 +42,7 @@ bool HasOutsideCompilationAttribute(Operation* op) { // Finds op that created a given value. If the value is a BlockArgument, this // returns the owner of the Block. Operation* GetOpOfValue(Value value) { - if (auto block_arg = value.dyn_cast()) + if (auto block_arg = mlir::dyn_cast(value)) return block_arg.getOwner()->getParentOp(); return value.getDefiningOp(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc index a2232f9f33bf2a..08165fb1435ff2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -61,12 +62,12 @@ LogicalResult ReplacePartitionedOp(IntegerAttr num_cores_per_replica, T op) { } auto element_type = getElementTypeOrSelf(first_operand_type); - if (element_type.isa()) { + if (mlir::isa(element_type)) { first_operand_type = - element_type.cast().getSubtypes().front(); + mlir::cast(element_type).getSubtypes().front(); } - auto tensor_type = first_operand_type.dyn_cast_or_null(); + auto tensor_type = mlir::dyn_cast_or_null(first_operand_type); if (!(tensor_type && tensor_type.hasRank())) { return op->emitError() << "cannot convert op with unranked or non-tensor input type " diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc index fa18fc25ce9c67..5f708ce0ee1a74 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc @@ -75,10 +75,14 @@ ResourceValueAndSubtype GetResourceWriteResult( // Checks if resource is read by TPU cluster. 
bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func, Value resource) { - for (Operation* resource_user : resource.getUsers()) - if (auto read = dyn_cast(resource_user)) - for (Operation* read_user : read.getValue().getUsers()) + for (Operation* resource_user : resource.getUsers()) { + if (auto read = dyn_cast(resource_user)) { + for (Operation* read_user : read.getValue().getUsers()) { if (read_user == cluster_func) return true; + if (isa(read_user)) return true; + } + } + } return false; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc index c6ce1428bfb3e4..ef16273e9eea45 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc @@ -90,7 +90,7 @@ LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { auto transform_result_type = RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input)); cast_input.setType(transform_result_type); - auto block_arg = cast_input.dyn_cast(); + auto block_arg = mlir::dyn_cast(cast_input); auto cast_op_input = dyn_cast_or_null(cast_input.getDefiningOp()); while (block_arg || cast_op_input) { if (block_arg) { @@ -105,7 +105,7 @@ LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input)); cast_input.setType(transform_result_type); // Update block arg and cast_op_input. - block_arg = cast_input.dyn_cast(); + block_arg = mlir::dyn_cast(cast_input); cast_op_input = dyn_cast_or_null(cast_input.getDefiningOp()); } } @@ -114,7 +114,7 @@ LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { // Handles padding before convolution for space to depth transform. LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) { - auto ranked_type = op.getInput().getType().dyn_cast(); + auto ranked_type = mlir::dyn_cast(op.getInput().getType()); if (!ranked_type) return failure(); auto pad_input_shape = ranked_type.getShape(); Location loc = op.getLoc(); @@ -164,7 +164,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) { // Transforms input shape for the first convolution. void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) { auto input = conv2d.getInput(); - auto input_shape = input.getType().cast().getShape(); + auto input_shape = mlir::cast(input.getType()).getShape(); SmallVector transform_shape = { input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size, input_shape[3] * block_size * block_size}; @@ -228,7 +228,7 @@ void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) { OpBuilder builder(conv2d); builder.setInsertionPoint(conv2d); // Book keeping filter information. 
-  auto filter_shape = filter.getType().cast<RankedTensorType>().getShape();
+  auto filter_shape = mlir::cast<RankedTensorType>(filter.getType()).getShape();
   int64_t height = filter_shape[0];
   int64_t width = filter_shape[1];
   int64_t channel = filter_shape[2];
@@ -422,7 +422,7 @@ bool HandleHostReplicatedInputs(int64_t index,
   }
   for (auto entry : llvm::enumerate(inputs)) {
     Value input = entry.value().get();
-    auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
+    auto ranked_type = mlir::dyn_cast<RankedTensorType>(input.getType());
     if (!ranked_type) return false;
     auto input_shape = ranked_type.getShape();
     auto space_to_depth =
@@ -442,7 +442,7 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
   llvm::SmallVector transform_input_indices;
   for (const auto& input : llvm::enumerate(cluster_func.getOperands())) {
-    if (auto block_arg = input.value().dyn_cast<BlockArgument>()) {
+    if (auto block_arg = mlir::dyn_cast<BlockArgument>(input.value())) {
       if (block_arg.getArgNumber() != arg_num) continue;
       // For a block argument, consider transforms only when it is a replicated
       // input (defining ops will be outside the replicate node).
@@ -462,7 +462,8 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
       continue;
     }
     if (!IsSupportedHostInputOp(input_op)) continue;
-    auto ranked_type = input.value().getType().dyn_cast<RankedTensorType>();
+    auto ranked_type =
+        mlir::dyn_cast<RankedTensorType>(input.value().getType());
     if (!ranked_type) continue;
     auto input_shape = ranked_type.getShape();
     HandleHostInput(input.value(), input.index(), cluster_func, block_size,
@@ -473,7 +474,7 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,

 // Checks if input shape of convolution is good for space to depth transform.
 bool Conv2DInputShapeCanTransform(Value input) {
-  auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
+  auto ranked_type = mlir::dyn_cast<RankedTensorType>(input.getType());
   if (!ranked_type) return false;
   auto input_shape = ranked_type.getShape();
   int32_t batch_size = input_shape[0];
@@ -486,7 +487,7 @@ bool Conv2DInputShapeCanTransform(Value input) {

 // Get block argument id and number of users for the input arg.
 std::optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
-  if (auto block_arg = arg.dyn_cast<BlockArgument>()) {
+  if (auto block_arg = mlir::dyn_cast<BlockArgument>(arg)) {
     if (!Conv2DInputShapeCanTransform(arg)) return std::nullopt;
     unsigned num_users =
         std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
@@ -540,9 +541,9 @@ std::optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
 void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
   // Check if input and filter type are RankedTensorType.
   auto input_tensor_type =
-      conv2d.getInput().getType().dyn_cast<RankedTensorType>();
+      mlir::dyn_cast<RankedTensorType>(conv2d.getInput().getType());
   auto filter_tensor_type =
-      conv2d.getFilter().getType().dyn_cast<RankedTensorType>();
+      mlir::dyn_cast<RankedTensorType>(conv2d.getFilter().getType());
   if (!input_tensor_type || !filter_tensor_type) return;
   // Bookkeeping filter shape for padding and backprop filter rewrite.
   auto filter_shape = filter_tensor_type.getShape();
@@ -550,7 +551,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
                                          filter_shape.end());
   // Handles input.
   auto conv2d_input = conv2d.getInput();
-  if (auto block_arg = conv2d_input.dyn_cast<BlockArgument>()) {
+  if (auto block_arg = mlir::dyn_cast<BlockArgument>(conv2d_input)) {
     // Change on device function type/shape.
     HandleFuncOp(block_arg.getOwner()->getParentOp());
   }
@@ -559,7 +560,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
   // Rewrite pad_op before Convolution.
   if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return;
   auto pad_input = pad_op.getInput();
-  if (auto block_arg = pad_input.dyn_cast<BlockArgument>()) {
+  if (auto block_arg = mlir::dyn_cast<BlockArgument>(pad_input)) {
     // Change on device function type/shape.
     HandleFuncOp(block_arg.getOwner()->getParentOp());
   }
@@ -573,7 +574,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
   // Bookkeeping new filter shape for backprop filter rewrite.
   // Filter shape is defined in HandleConv2DFilter, thus it is RankedTensorType.
   filter_shape =
-      conv2d.getFilter().getType().cast<RankedTensorType>().getShape();
+      mlir::cast<RankedTensorType>(conv2d.getFilter().getType()).getShape();
   SmallVector<int64_t, 4> new_filter_shape(filter_shape.begin(),
                                            filter_shape.end());

@@ -593,7 +594,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
 int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
   SmallVector<int32_t, 4> strides(4, 1);
   for (int i = 0; i < 3; ++i) {
-    strides[i] = conv2d.getStrides()[i].cast<IntegerAttr>().getInt();
+    strides[i] = mlir::cast<IntegerAttr>(conv2d.getStrides()[i]).getInt();
   }

   // Space to depth only supports striding at spatial dimension.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc
index 4dc9daa6c705ee..21f62e41383401 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc
@@ -372,7 +372,8 @@ bool CheckOpsClusterIO(Operation* op, MetadataMap& metadata_map) {

 bool TypeMustBeNonXLA(const Type& type) {
   const Type elem = getElementTypeOrSelf(type);
-  return !elem.isa<TF::ResourceType>() && !tensorflow::TypeValidForXLA(type);
+  return !mlir::isa<TF::ResourceType>(elem) &&
+         !tensorflow::TypeValidForXLA(type);
 }

 // Check if the op cannot be XLA compiled.
If the op does not satisfy this @@ -539,6 +540,18 @@ bool IsValidMAXIMALSharding(Operation* op, MetadataMap& metadata_map) { return true; } +bool HasSingleCoreTpu(Operation* op) { + if (auto compilation_attr = + op->getAttrOfType(TF::kCompileDeviceTypeAttr)) { + if (compilation_attr.getValue().str() == TF::kTpuDevice) { + op->emitOpError( + "TF2XLA TPU bridge input check: found a single-core TPU graph"); + return true; + } + } + return false; +} + void TPUValidateInputsPass::runOnOperation() { ModuleOp module = getOperation(); bool success = true; @@ -563,10 +576,11 @@ void TPUValidateInputsPass::runOnOperation() { success &= IsValidMAXIMALSharding(op, metadata_map); success &= IsValidShardingTupleForArity(op); } + success &= !HasSingleCoreTpu(op); + if (!success) { + signalPassFailure(); + } }); - if (!success) { - signalPassFailure(); - } } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index abdd1a83d516eb..ff8ac1ad7cacd1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -96,7 +96,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp::createReshapeOp( template std::vector ConvertTFBatchMatMulOp::sliceInput( Value value, int batch_size, Location loc, PatternRewriter& rewriter) { - RankedTensorType tensorType = value.getType().cast(); + RankedTensorType tensorType = mlir::cast(value.getType()); Type element_type = tensorType.getElementType(); int rank = tensorType.getShape().size(); @@ -150,17 +150,17 @@ LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( Value input_lhs = op.getX(); Value input_rhs = op.getY(); - if (!input_lhs.getType().isa()) { + if (!mlir::isa(input_lhs.getType())) { // LHS must be a ranked tensor type return failure(); } - if (!input_rhs.getType().isa()) { + if (!mlir::isa(input_rhs.getType())) { // RHS must be a ranked tensor type return failure(); } - auto lhs_type = input_lhs.getType().cast(); - auto rhs_type = input_rhs.getType().cast(); + auto lhs_type = mlir::cast(input_lhs.getType()); + auto rhs_type = mlir::cast(input_rhs.getType()); // Skip int8 x int8 => int32. if (lhs_type.getElementType().isInteger(8) && diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc index 20dcdb8b034c97..9237ff8d5b69dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo // IWYU pragma: keep #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep @@ -156,8 +157,8 @@ LogicalResult SymbolizeCustomCallCalledIndex( return WalkResult::interrupt(); } - auto called_index_attr = backend_config.get(kCalledIndexAttrName) - .dyn_cast_or_null(); + auto called_index_attr = mlir::dyn_cast_or_null( + backend_config.get(kCalledIndexAttrName)); if (!called_index_attr) { op->emitOpError() << "is missing attribute '" << kCalledIndexAttrName << "'"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc index a75bf4c75d8033..6ab5da6bdb2e3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/api/PortableApi.h" // from @stablehlo #include "stablehlo/dialect/Serialization.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep @@ -66,8 +67,8 @@ FailureOr DesymbolizeCustomCallCalledIndex(ModuleOp module) { << "'"; return WalkResult::interrupt(); } - auto called_func = backend_config.get(kCalledFuncAttrName) - .dyn_cast_or_null(); + auto called_func = mlir::dyn_cast_or_null( + backend_config.get(kCalledFuncAttrName)); if (!called_func) { op->emitOpError() << "is missing attribute '" << kCalledFuncAttrName << "'"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc index 1992f43a951184..8ce264b47b57d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc @@ -18,6 +18,7 @@ limitations under the License. #include +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -44,8 +45,7 @@ void MoveResourceArgsToEnd(func::FuncOp callee) { // Copy the resource-type parameters to the end. 
@@ -44,8 +45,7 @@ void MoveResourceArgsToEnd(func::FuncOp callee) {
   // Copy the resource-type parameters to the end.
   for (unsigned i = 0; i < num_params; ++i) {
     BlockArgument param = callee.getArgument(i);
-    if (getElementTypeOrSelf(param.getType())
-            .template isa<TF::ResourceType>()) {
+    if (mlir::isa<TF::ResourceType>(getElementTypeOrSelf(param.getType()))) {
       removed_params.set(i);
       callee.getBody().addArgument(param.getType(), param.getLoc());
       param.replaceAllUsesWith(callee.getArguments().back());
@@ -65,7 +65,7 @@ void RewriteCall(tf_device::ClusterFuncOp cluster_func_op, SymbolTable &symtab,
   llvm::SmallVector<Value> non_resource_args, resource_args;
   bool has_resources = false, in_order = true;
   for (const Value &arg : cluster_func_op.getOperands()) {
-    if (!getElementTypeOrSelf(arg.getType()).template isa<TF::ResourceType>()) {
+    if (!mlir::isa<TF::ResourceType>(getElementTypeOrSelf(arg.getType()))) {
       non_resource_args.push_back(arg);
       if (has_resources) in_order = false;
     } else {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc
index 24ae9056866ad4..9267607e7e342a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc
@@ -65,7 +65,7 @@ LogicalResult HasNoNestedEntryFunctions(
       // "tf_saved_model.initializer_type" attribute from the callee of the
       // inner calls if the problem ever arises.
       entry_func->emitError()
-          << "TF2XLA MLIR CPU/GPU phase 1 bridge expects no nested calls"
+          << "TF2XLA MLIR Non-replicated Phase 1 Bridge expects no nested calls"
             " of entry functions as they prevent graph traversal in some "
             "passes from "
            "working correctly";
@@ -75,15 +75,13 @@
   return success();
 }

-// MLIR CPU/GPU phase 1 pipeline assumes an entry function has single region and
-// single block when handling top-level compilation markers.
-LogicalResult HasSingleBlockEntryFunctions( - llvm::SmallVector &entry_funcs, SymbolTable &symtab) { +LogicalResult HasTopLevelCompilationMarker( + llvm::SmallVector &entry_funcs) { for (auto &entry_func : entry_funcs) { - if (!HasSingleBlock(entry_func)) { - entry_func->emitError() << "TF2XLA MLIR CPU/GPU MLIR phase 1 bridge " - "expects single region and single " - "block in an entry function."; + if (entry_func->hasAttr(mlir::TF::kCompileDeviceTypeAttr)) { + entry_func->emitError() << "TF2XLA MLIR Non-replicated Phase 1 Bridge " + "does not support top-level compilation " + "marker."; return failure(); } } @@ -102,7 +100,7 @@ void XlaValidateInputsPass::runOnOperation() { return signalPassFailure(); } - if (HasSingleBlockEntryFunctions(entry_funcs, symtab).failed()) { + if (HasTopLevelCompilationMarker(entry_funcs).failed()) { return signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD index 59d7cfd7081106..f0280340dddf62 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -107,53 +107,11 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:IR", - "@local_xla//xla:status_macros", - ], -) - -cc_library( - name = "translate_tf_dialect_op", - srcs = ["translate_tf_dialect_op.cc"], - deps = [ - ":export_tf_dialect_op", - "//tensorflow/compiler/mlir/tensorflow", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", - "@local_tsl//tsl/platform:protobuf", - ], - alwayslink = 1, -) - -cc_library( - name = "mlir_roundtrip_pass", - srcs = ["mlir_roundtrip_pass.cc"], - hdrs = ["mlir_roundtrip_pass.h"], - deps = [ - ":export_graphdef", - ":import_model", - ":mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", "@local_xla//xla:status_macros", ], ) -cc_library( - name = "mlir_roundtrip_pass_registration", - srcs = ["mlir_roundtrip_pass_registration.cc"], - deps = [ - ":mlir_roundtrip_pass", - ], - alwayslink = 1, -) - cc_library( name = "mlir_roundtrip_flags", srcs = ["mlir_roundtrip_flags.cc"], @@ -209,6 +167,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 523048cd7cd582..0d8a75e7f7de9d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -42,6 +42,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -129,7 +130,7 @@ class Exporter { // function are added to the graph with special op names kArgOp and kRetOp. 
// Later on, this graph can be converted a function definition and added to // another graph. - static StatusOr> Convert( + static absl::StatusOr> Convert( const GraphExportConfig& configs, const Dialect* tf_dialect, const SymbolTable& symbol_table, FuncOp function, FunctionLibraryDefinition* flib_def, @@ -155,13 +156,12 @@ class Exporter { Status AddEdge(Operation* inst); - StatusOr> GetArgumentNode(BlockArgument arg, - unsigned index, - llvm::StringRef name); - StatusOr> GetReturnNode(FuncOp function, - Value operand, - unsigned index, - llvm::StringRef name); + absl::StatusOr> GetArgumentNode( + BlockArgument arg, unsigned index, llvm::StringRef name); + absl::StatusOr> GetReturnNode(FuncOp function, + Value operand, + unsigned index, + llvm::StringRef name); Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch, absl::flat_hash_set* control_ret_nodes); // Adds one edge between src_node and dst_node. If it is not a control edge, @@ -192,7 +192,7 @@ std::string FindFunctionName(const GraphExportConfig& configs, FuncOp func) { return func.getName().str(); } -StatusOr> Exporter::GetArgumentNode( +absl::StatusOr> Exporter::GetArgumentNode( BlockArgument arg, unsigned index, llvm::StringRef name) { auto func = arg.getParentRegion()->getParentOfType(); @@ -205,9 +205,9 @@ StatusOr> Exporter::GetArgumentNode( node_def->set_op(FunctionLibraryDefinition::kArgOp); - mlir::TensorType arg_type = arg.getType().cast(); + mlir::TensorType arg_type = mlir::cast(arg.getType()); if (auto resource_type = - arg_type.getElementType().dyn_cast()) { + mlir::dyn_cast(arg_type.getElementType())) { llvm::ArrayRef subtypes = resource_type.getSubtypes(); if (!subtypes.empty()) { AttrValue handle_dtypes_attr; @@ -254,7 +254,7 @@ StatusOr> Exporter::GetArgumentNode( return node_def; } -StatusOr> Exporter::GetReturnNode( +absl::StatusOr> Exporter::GetReturnNode( FuncOp function, Value operand, unsigned index, llvm::StringRef name) { auto node_def = std::make_unique(); if (!name.empty()) @@ -266,7 +266,8 @@ StatusOr> Exporter::GetReturnNode( node_def->set_op(FunctionLibraryDefinition::kRetOp); DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( - operand.getType().cast().getElementType(), &dtype)); + mlir::cast(operand.getType()).getElementType(), + &dtype)); AttrValue type_attr; type_attr.set_type(dtype); (*node_def->mutable_attr())["T"] = type_attr; @@ -290,7 +291,7 @@ StatusOr> Exporter::GetReturnNode( Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index) { - if (auto input_result = src.dyn_cast()) { + if (auto input_result = mlir::dyn_cast(src)) { auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner()); // Replaces the input node with NextIteration sink if it is a NextIteration // source. 
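From here on, the export_graphdef hunks also fold in the status-type cleanup: `tensorflow::Status`, `OkStatus()`, and `StatusOr<T>` are spelled as the underlying `absl::Status`, `absl::OkStatus()`, and `absl::StatusOr<T>`. A minimal sketch of the pattern, assuming only Abseil (the function names are invented for illustration):

```cpp
#include "absl/status/status.h"
#include "absl/status/statusor.h"

absl::StatusOr<int> ParsePositive(int raw) {
  if (raw <= 0) return absl::InvalidArgumentError("value must be positive");
  return raw;  // implicitly wraps the value in an ok StatusOr
}

absl::Status Validate(int raw) {
  // Previously written as tensorflow::Status / OkStatus() via aliases.
  return ParsePositive(raw).status();
}
```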
@@ -302,23 +303,23 @@ Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
     auto node_it = nodes_.find(input_inst);
     TF_RET_CHECK(node_it != nodes_.end())
         << "Use of OpResult encountered before def!";
-    if (input_result.getType().isa<mlir::tf_executor::ControlType>()) {
+    if (mlir::isa<mlir::tf_executor::ControlType>(input_result.getType())) {
       graph_->AddControlEdge(node_it->second, dst_node,
                              /*allow_duplicates=*/true);
     } else {
       graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
                       dst_index);
     }
-    return OkStatus();
+    return absl::OkStatus();
   }

-  auto input_arg = src.cast<mlir::BlockArgument>();
+  auto input_arg = mlir::cast<mlir::BlockArgument>(src);
   auto input_node_it = args_.find(input_arg);
   TF_RET_CHECK(input_node_it != args_.end())
       << "Use of BlockArgument encountered before def!";
   // For argument, there is only one result output, so the index is always 0.
   graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index);
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status Exporter::AddEdge(Operation* inst) {
@@ -327,13 +328,13 @@ Status Exporter::AddEdge(Operation* inst) {
   if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
     for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
       Value operand = operand_and_idx.value();
-      if (operand.getType().isa<mlir::tf_executor::ControlType>()) break;
+      if (mlir::isa<mlir::tf_executor::ControlType>(operand.getType())) break;

       auto* dst_node = returns_[fetch][operand_and_idx.index()];
       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0));
     }

-    return OkStatus();
+    return absl::OkStatus();
   }

   // For tf_executor.NextIteration.Sink, skip its token operand and add data and
@@ -348,14 +349,14 @@ Status Exporter::AddEdge(Operation* inst) {
       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node,
                                              control_and_idx.index() + 1));

-    return OkStatus();
+    return absl::OkStatus();
   }

   // For tf_executor.NextIteration.Source, op can be skipped as it is assumed
   // there are no operands.
   if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
     assert(inst->getNumOperands() == 0);
-    return OkStatus();
+    return absl::OkStatus();
   }

   Operation* op = GetIslandInnerOpOrSelf(inst);
@@ -377,7 +378,7 @@ Status Exporter::AddEdge(Operation* inst) {
         AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
                             operand_and_idx.index() + operand_offset));

-  return OkStatus();
+  return absl::OkStatus();
 }

 void Exporter::UseOriginalFunctionNames(NodeDef& node_def) {
@@ -424,7 +425,7 @@ Status Exporter::AddInstructionNode(Operation* inst) {
   TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(std::move(*node_def)));
   DCHECK(node != nullptr);
   nodes_[inst] = node;
-  return OkStatus();
+  return absl::OkStatus();
 }

 bool IsEntryFunctionArg(BlockArgument arg) {
@@ -438,7 +439,7 @@ Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
   TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
   TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(std::move(*node_def)));
   args_[arg] = node;
-  return OkStatus();
+  return absl::OkStatus();
 }

 // Creates return nodes per operand of a FetchOp.
If names is supplied, those @@ -447,7 +448,8 @@ Status Exporter::AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch, llvm::ArrayRef names) { auto& return_nodes = returns_[fetch]; for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) { - if (operand_and_idx.value().getType().isa()) + if (mlir::isa( + operand_and_idx.value().getType())) break; TF_ASSIGN_OR_RETURN( @@ -458,7 +460,7 @@ Status Exporter::AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch, TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(std::move(*node_def))); return_nodes.push_back(node); } - return OkStatus(); + return absl::OkStatus(); } // Collects control ret Nodes based on tf_executor.graph's associated @@ -467,7 +469,7 @@ Status Exporter::GetControlRetNodes( mlir::tf_executor::FetchOp fetch, absl::flat_hash_set* control_ret_nodes) { for (Value fetch_operand : fetch.getOperands()) { - if (fetch_operand.getType().isa()) { + if (mlir::isa(fetch_operand.getType())) { Operation* defining_op = GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp()); auto node_it = nodes_.find(defining_op); @@ -475,7 +477,7 @@ Status Exporter::GetControlRetNodes( control_ret_nodes->insert(node_it->second); } } - return OkStatus(); + return absl::OkStatus(); } // After conversion from MLIR the input names are all blank which causes @@ -494,7 +496,7 @@ void FixupInputNamesFromEdges(Graph* graph) { } } } -StatusOr> Exporter::Convert( +absl::StatusOr> Exporter::Convert( const GraphExportConfig& configs, const Dialect* tf_dialect, const SymbolTable& symbol_table, FuncOp function, FunctionLibraryDefinition* flib_def, @@ -509,14 +511,16 @@ StatusOr> Exporter::Convert( auto dict_attr = function->getAttrOfType(kEntryFuncAttr); if (dict_attr) { - TF_RET_CHECK(dict_attr.get("inputs").isa()) + TF_RET_CHECK(mlir::isa(dict_attr.get("inputs"))) << "inputs missing in entry function attribute"; - TF_RET_CHECK(dict_attr.get("outputs").isa()) + TF_RET_CHECK(mlir::isa(dict_attr.get("outputs"))) << "outputs missing in entry function attribute"; - dict_attr.get("inputs").cast().getValue().split( - input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); - dict_attr.get("outputs").cast().getValue().split( - output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); + mlir::cast(dict_attr.get("inputs")) + .getValue() + .split(input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); + mlir::cast(dict_attr.get("outputs")) + .getValue() + .split(output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); } auto graph = std::make_unique(OpRegistry::Global()); @@ -582,7 +586,7 @@ StatusOr> Exporter::Convert( int index = it.index(); auto arg = it.value(); mlir::Type type = arg.getType(); - if (!type.isa()) { + if (!mlir::isa(type)) { return errors::InvalidArgument( "FuncOps arguments must have tensor types. Found ", mlir::debugString(type), " in function ", function.getName().str()); @@ -601,14 +605,14 @@ StatusOr> Exporter::Convert( // library rather than the all the functions exported so far. TF_RETURN_IF_ERROR(graph->mutable_flib_def()->AddLibrary(*flib_def)); } - return OkStatus(); + return absl::OkStatus(); }; // Adds nodes for operations. for (Operation& inst : graph_op.GetBody()) { for (auto type : inst.getResultTypes()) - if (!type.isa()) + if (!mlir::isa(type)) return errors::InvalidArgument( "Values must be of tensor type, TensorFlow control type, or " "TensorFlow token type. 
Found ", @@ -669,7 +673,7 @@ Status Exporter::ConvertLibFunction( llvm::SmallDenseSet& visited_functions) { // Return early if the function has already been exported. bool is_new_function = visited_functions.insert(function).second; - if (!is_new_function) return OkStatus(); + if (!is_new_function) return absl::OkStatus(); auto function_name = FindFunctionName(configs, function); @@ -780,7 +784,7 @@ Status Exporter::Convert(mlir::ModuleOp module, if (flib_def != nullptr) { TF_RETURN_IF_ERROR(flib_def->AddLibrary(temp_flib_def)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -805,7 +809,7 @@ Status ConvertMlirToGraph(mlir::ModuleOp module, &control_ret_nodes); } -StatusOr> ConvertMlirToGraphdef( +absl::StatusOr> ConvertMlirToGraphdef( mlir::ModuleOp module, const GraphExportConfig& configs) { FunctionLibraryDefinition flib_def(OpRegistry::Global(), FunctionDefLibrary()); @@ -825,7 +829,7 @@ StatusOr> ConvertMlirToGraphdef( return graphdef; } -tsl::Status ConvertMlirFunctionToFunctionLibraryDef( +absl::Status ConvertMlirFunctionToFunctionLibraryDef( FuncOp func, const GraphExportConfig& configs, FunctionDef* function_def) { Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf"); FunctionLibraryDefinition flib_def(OpRegistry::Global(), @@ -844,7 +848,7 @@ tsl::Status ConvertMlirFunctionToFunctionLibraryDef( const FunctionDef* func_def = flib_def.Find(name); if (func_def != nullptr) { *function_def = *func_def; - return OkStatus(); + return absl::OkStatus(); } return absl::InvalidArgumentError( absl::StrCat("Function '", name, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h index 562226e1c764bc..e5e62a3e05a330 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h @@ -30,7 +30,7 @@ limitations under the License. namespace tensorflow { // Given an MLIR module, returns a GraphDef. -tsl::StatusOr> ConvertMlirToGraphdef( +absl::StatusOr> ConvertMlirToGraphdef( mlir::ModuleOp module, const GraphExportConfig& configs); // Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index 6ce83519a0fe6d..debe84b63bbd1c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" @@ -59,7 +60,7 @@ Status SetTypeAttribute(absl::string_view name, ContainerT types, assert(result.second && "cannot have multiple attributes with the same name"); (void)result; - return OkStatus(); + return absl::OkStatus(); } // Sets shape list attribute with the given `name` to the given `shapes`. If the @@ -97,7 +98,7 @@ Status GetUnregisteredAttrs( absl::flat_hash_set* attrs_to_ignore) { if (!op_reg_data) { // This is likely a function call node, so we should continue. 
- return OkStatus(); + return absl::OkStatus(); } // Collect all the registered attributes. @@ -114,7 +115,7 @@ Status GetUnregisteredAttrs( absl::string_view(attr.getName().data(), attr.getName().size())); } } - return OkStatus(); + return absl::OkStatus(); } // Collects all attribute names to ignore in an MLIR operation when exporting to @@ -183,7 +184,7 @@ Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name, auto values = inst->getResults(); auto begin = values.begin(); auto end = values.begin(); - while (end != values.end() && (*end).getType().isa()) + while (end != values.end() && mlir::isa((*end).getType())) end++; if (begin != end) { mlir::TF::ResultShapeRange output_shapes = { @@ -193,7 +194,7 @@ Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name, } } - return OkStatus(); + return absl::OkStatus(); } // A `Cast` with DstT == SrcT can be introduced in MLIR as a shape cast. But @@ -253,7 +254,7 @@ Status GetAttrValuesFromOperation( value.mutable_func()->set_name(""); (*attributes)[kShapeInferenceGraph] = value; } - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr> ConvertTFDialectOpToNodeDef( diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index ae87dad305a7e0..f15e741b247340 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -42,7 +42,7 @@ Status GetAttrValuesFromOperation( // ShapedType for the leading values with ShapedType in the results of the // nodes. Set it to true if the returned NodeDef will be executed by the linked // TF Eager runtime. -StatusOr> ConvertTFDialectOpToNodeDef( +absl::StatusOr> ConvertTFDialectOpToNodeDef( mlir::Operation* inst, llvm::StringRef name, bool ignore_unregistered_attrs); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 42b059cbd0a527..3e72550a88749a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -68,6 +68,7 @@ limitations under the License. #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -528,7 +529,7 @@ Status UpdateLegacyFedInputNode(const GraphDef& graph_def, auto it = inputs.find(node_name); // Node is not an input. 
- if (it == inputs.end()) return OkStatus(); + if (it == inputs.end()) return absl::OkStatus(); if (HasNonPrimaryOutputInUse(graph_def, node_name)) { return errors::InvalidArgument( @@ -549,7 +550,7 @@ Status UpdateLegacyFedInputNode(const GraphDef& graph_def, node->clear_input(); AddNodeAttr("dtype", dtype, node); AddNodeAttr("shape", it->second.shape, node); - return OkStatus(); + return absl::OkStatus(); } // Preprocesses GraphDef before it can be converted to Graph by, @@ -575,7 +576,7 @@ Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) { } ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def); } - return OkStatus(); + return absl::OkStatus(); } // Mapping from node name to feed (index and ArrayInfo). Node name must outlive @@ -691,7 +692,7 @@ Status ImporterBase::ConvertDeferredFunctions() { } } - return OkStatus(); + return absl::OkStatus(); } Status ImporterBase::RemoveBackedges() { @@ -718,7 +719,7 @@ Status ImporterBase::RemoveBackedges() { GetReversePostOrder( *graph_, &ordered_nodes_, [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); }); - return OkStatus(); + return absl::OkStatus(); } Status CopyStackTraces(const Graph& from, Graph* to) { @@ -744,7 +745,7 @@ Status CopyStackTraces(const Graph& from, Graph* to) { } } - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr> @@ -808,7 +809,7 @@ Status ImporterBase::GetInputOutputNodes( absl::StrCat("Graph does not contain node: ", name)); } nodes->insert(it->second); - return OkStatus(); + return absl::OkStatus(); }; // Remap feeds and fetches to newly created Placeholder nodes. @@ -835,7 +836,7 @@ Status ImporterBase::GetInputOutputNodes( for (const auto& control_output : specs_.control_outputs) TF_RETURN_IF_ERROR(add_node(control_output)); - return OkStatus(); + return absl::OkStatus(); } // TODO(jpienaar): Remove this post shape inference on import flag is removed. 
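The hunks in this file repeat one mechanical substitution: bare `OkStatus()` (pulled in from the `tensorflow` namespace) becomes the fully qualified `absl::OkStatus()`, now that TensorFlow's `Status` is an alias of `absl::Status`. A minimal standalone sketch of the resulting style, using only the real `absl::Status` API; the `ValidateNodeCount` helper is hypothetical and not part of this change:

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

// Hypothetical helper showing the migrated convention: success paths
// return absl::OkStatus(), error paths return a typed absl error.
absl::Status ValidateNodeCount(int node_count) {
  if (node_count < 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("negative node count: ", node_count));
  }
  return absl::OkStatus();
}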
@@ -934,7 +935,7 @@ Status ImporterBase::AddNodesToShapeRefiner( << kOutputShapesAttrName << " attribute specifies shapes for " << list.shape_size() << " outputs"; - return OkStatus(); + return absl::OkStatus(); } for (const auto& shape : llvm::enumerate(list.shape())) { @@ -947,7 +948,7 @@ Status ImporterBase::AddNodesToShapeRefiner( } node_context->set_output(shape.index(), handle); } - return OkStatus(); + return absl::OkStatus(); }; // If it is the argument node, the shape handle is set explicitly, so it @@ -1069,7 +1070,7 @@ Status ImporterBase::AddNodesToShapeRefiner( } VLOG(1) << "Graph shapes were inferred with " << (i - 1) << " extra rounds of analysis to reach a fixpoint."; - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr ImporterBase::InferInputType(const Node& node, @@ -1228,7 +1229,7 @@ absl::StatusOr ImporterBase::InferOutputType( TF_ASSIGN_OR_RETURN( auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder)); return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get( - {etype.cast()}, builder.getContext())); + {mlir::cast(etype)}, builder.getContext())); } else { return mlir::UnrankedTensorType::get( mlir::TF::ResourceType::get(builder.getContext())); @@ -1331,7 +1332,7 @@ Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name, NamedAttrList* attributes) { TF_ASSIGN_OR_RETURN(auto func_attr, ConvertFunctionCallName(value.func().name())); - if (!func_attr) return OkStatus(); + if (!func_attr) return absl::OkStatus(); attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); for (const auto& it : value.func().attr()) { @@ -1339,7 +1340,7 @@ Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name, TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second)); attributes->push_back(builder_.getNamedAttr(name, value)); } - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr ImporterBase::ConvertFunctionCallName( @@ -1411,7 +1412,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { // done. 
if (tf_name_to_mlir_name_->find(std::string(func_name)) != tf_name_to_mlir_name_->end()) - return OkStatus(); + return absl::OkStatus(); std::string mlir_func_name( function_name_uniquifier_->GetUniqueName(func_name)); @@ -1458,7 +1459,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { } deferred_functions_.emplace(func_name.str(), attributes); - return OkStatus(); + return absl::OkStatus(); } Status ImporterBase::PruneUnreachableNodes( @@ -1475,7 +1476,7 @@ Status ImporterBase::PruneUnreachableNodes( } else { VLOG(1) << "No output nodes specified, skipping pruning"; } - return OkStatus(); + return absl::OkStatus(); } Status ImporterBase::ConvertFeedsToPlaceholders( @@ -1524,7 +1525,7 @@ Status ImporterBase::ConvertFeedsToPlaceholders( } } } - return OkStatus(); + return absl::OkStatus(); } Status ImporterBase::PrepareConvert(const Graph& graph, @@ -1568,7 +1569,7 @@ Status ImporterBase::PrepareConvert(const Graph& graph, [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); }); } - return OkStatus(); + return absl::OkStatus(); } Status ImporterBase::Convert( @@ -1622,7 +1623,7 @@ Status ImporterBase::Convert( } } - return OkStatus(); + return absl::OkStatus(); } Status ImporterBase::ConvertFunctionArgAndRets( @@ -1659,7 +1660,7 @@ Status ImporterBase::ConvertFunctionArgAndRets( ret_attrs[index].set(dialect_attribute, converted_attr); } } - return OkStatus(); + return absl::OkStatus(); }; auto* bb = &func.front(); @@ -1753,7 +1754,7 @@ Status ImporterBase::ConvertFunctionArgAndRets( return list.getDictionary(context_); }))); - return OkStatus(); + return absl::OkStatus(); } mlir::Location ImporterBase::GetLocation(const Node& node) { @@ -2000,7 +2001,7 @@ mlir::Operation* ImporterBase::CreateOperation( record_resource = [&](mlir::Type type) { type.walk([&](mlir::Type t) { if (resource) return mlir::WalkResult::interrupt(); - if (type.isa()) { + if (mlir::isa(type)) { resource = true; return mlir::WalkResult::interrupt(); } @@ -2035,7 +2036,7 @@ Status ImporterBase::ConvertNode(const Node& node) { if (!node.IsOp()) { // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by // Graph and don't exist in GraphDef. - return OkStatus(); + return absl::OkStatus(); } // If it is a custom OP, its definition should be found in the library. We @@ -2223,7 +2224,7 @@ Status ImporterBase::ConvertNode(const Node& node) { // Register the mapping between the TF node and the newly created operation. node_values_[node.id()] = CreateOperation(node, node_type_name, result, control_operands); - return OkStatus(); + return absl::OkStatus(); } // Add the backedges to the CFG. Given a backedge, we replace the original @@ -2249,7 +2250,7 @@ Status ImporterBase::AddBackedges() { auto* dst = node_values_[edge.dst->id()]; TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input)); } - return OkStatus(); + return absl::OkStatus(); } Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, @@ -2285,7 +2286,7 @@ Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, } dst->dropAllReferences(); dst->erase(); - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr ImporterBase::InferLibFunctionType( @@ -2714,7 +2715,7 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( (*nodes)[index].node->name(), "'"); (*nodes)[index] = {node, 0}; - return OkStatus(); + return absl::OkStatus(); }; // Collect arg and ret nodes from graph. 
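The companion change visible in the hunks above swaps the deprecated member-function casts (`x.isa<T>()`, `x.cast<T>()`, `x.dyn_cast<T>()`) for the free functions `mlir::isa`, `mlir::cast`, and `mlir::dyn_cast`, which is also why `mlir/Support/LLVM.h` joins the include lists throughout this patch. A small sketch of the new style against real MLIR builtin types; the `IsRankedFloatTensor` helper itself is illustrative, not part of the patch:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

// Illustrative helper: classify a type using the free-function cast API.
bool IsRankedFloatTensor(mlir::Type type) {
  // mlir::dyn_cast yields a null handle when the cast fails.
  if (auto tensor = mlir::dyn_cast<mlir::RankedTensorType>(type)) {
    return mlir::isa<mlir::FloatType>(tensor.getElementType());
  }
  return false;
}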
@@ -2758,7 +2759,7 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( Status GraphDefImporter::GetControlRetsFromGraph( llvm::ArrayRef control_outputs, absl::InlinedVector* control_ret_nodes) { - if (control_outputs.empty()) return OkStatus(); + if (control_outputs.empty()) return absl::OkStatus(); llvm::SmallDenseMap controls_to_idx; for (const auto& control_and_idx : llvm::enumerate(control_outputs)) @@ -2779,7 +2780,7 @@ Status GraphDefImporter::GetControlRetsFromGraph( return errors::InvalidArgument( "Control output '", std::get<1>(node_and_name), "' is missing"); - return OkStatus(); + return absl::OkStatus(); } // Stateful helper class to import a TensorFlow model expressed in SavedModel @@ -3059,7 +3060,7 @@ Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph, } } } - return OkStatus(); + return absl::OkStatus(); } // Recursively traverses a StructuredValue, linearizing all the leaves. @@ -3187,10 +3188,10 @@ void StructuredValueLinearizer::RecursivelyFindLeaves( << " at index path: "; for (auto path_element : current_index_path_) { os << "."; - if (auto integer = path_element.dyn_cast()) { + if (auto integer = mlir::dyn_cast(path_element)) { os << integer.getValue(); } else { - auto str = path_element.cast(); + auto str = mlir::cast(path_element); os << str.getValue(); } } @@ -3357,7 +3358,7 @@ Status CreateSavedModelIR( const TrackableObjectGraph::TrackableObject& trackable_object) { restored_objects.insert( std::make_pair(saved_node_id, &trackable_object)); - return OkStatus(); + return absl::OkStatus(); })); for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) { @@ -3554,7 +3555,7 @@ Status CreateSavedModelIR( module->setAttr("tf_saved_model.semantics", builder.getUnitAttr()); SortSavedModelModule(module); MarkSavedModelFunctionVisibility(module); - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr> @@ -3893,7 +3894,7 @@ Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule( symbol_table_.insert(func.clone()); } - return OkStatus(); + return absl::OkStatus(); } Status SavedModelSignatureDefImporterLite::ConvertInitializer( @@ -4271,7 +4272,7 @@ Status SavedModelSignatureDefImporter::LiftVariables( return diag_handler.Combine( errors::Internal("Failed to dedup bound inputs.")); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 1670fd11a1f819..bca1f7f80af9e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -39,13 +39,13 @@ inline constexpr absl::string_view kImportModelDefaultGraphFuncName = "main"; // Given a GraphDef, returns a MLIR module containing the graph, expressed with // tf_executor dialect. -tsl::StatusOr> ConvertGraphdefToMlir( +absl::StatusOr> ConvertGraphdefToMlir( const GraphDef& graphdef, const GraphDebugInfo& debug_info, const GraphImportConfig& specs, mlir::MLIRContext* context); // Given a Graph, returns a MLIR module containing the graph, expressed with // tf_executor dialect. 
-tsl::StatusOr> ConvertGraphToMlir( +absl::StatusOr> ConvertGraphToMlir( const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, mlir::MLIRContext* context); @@ -53,19 +53,19 @@ tsl::StatusOr> ConvertGraphToMlir( // [Experimental] // Given a Function, returns a MLIR module containing the graph, expressed with // tf_executor dialect. -tsl::StatusOr> ConvertFunctionToMlir( +absl::StatusOr> ConvertFunctionToMlir( const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def, mlir::MLIRContext* context); // Given a SavedModel, returns a MLIR module containing the functions, expressed // with tf_executor dialect. -tsl::StatusOr> ConvertSavedModelToMlir( +absl::StatusOr> ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, MLIRImportOptions options = {}); // Given a V1 SavedModel, returns a MLIR module containing the functions, // expressed with tf_executor dialect. -tsl::StatusOr> ConvertSavedModelV1ToMlir( +absl::StatusOr> ConvertSavedModelV1ToMlir( const SavedModelBundle& saved_model, absl::Span exported_names, mlir::MLIRContext* context, MLIRImportOptions options = {}); @@ -79,7 +79,7 @@ tsl::StatusOr> ConvertSavedModelV1ToMlir( // ConvertSavedModelV1ToMlir(), and is not related to TFLite. // // TODO(b/179683149): Rename this class to avoid confusion with TFLite. -tsl::StatusOr> ConvertSavedModelV1ToMlirLite( +absl::StatusOr> ConvertSavedModelV1ToMlirLite( const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info, std::optional> exported_names, mlir::MLIRContext* context, MLIRImportOptions options); @@ -112,8 +112,8 @@ class SavedModelMLIRImportInput { // and remain valid for the graph. // `name` is a unique identifier for this subgraph, so the implementation can // use it for eg. debugging or caching compilation results. - virtual tsl::StatusOr GetSubGraph(absl::string_view name, - GraphImportConfig& specs) = 0; + virtual absl::StatusOr GetSubGraph( + absl::string_view name, GraphImportConfig& specs) = 0; private: const MetaGraphDef* meta_graph_def_ = nullptr; @@ -131,7 +131,7 @@ class SavedModelMLIRImportInput { // ConvertSavedModelV1ToMlir(), and is not related to TFLite. // // TODO(b/179683149): Rename this class to avoid confusion with TFLite. 
-tsl::StatusOr> ConvertSavedModelV1ToMlirLite( +absl::StatusOr> ConvertSavedModelV1ToMlirLite( SavedModelMLIRImportInput& input, std::optional> exported_names, mlir::MLIRContext* context, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc index 09115768652d32..dbccd07976a997 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc @@ -70,7 +70,7 @@ std::string GraphImportConfig::str() const { Status ParseOutputArrayInfo(absl::string_view array_names, std::vector* outputs) { TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs)); - return OkStatus(); + return absl::OkStatus(); } Status ParseOutputArrayInfo(const std::vector& output_names, @@ -79,7 +79,7 @@ Status ParseOutputArrayInfo(const std::vector& output_names, if (output_name.empty()) continue; outputs->push_back(output_name); } - return OkStatus(); + return absl::OkStatus(); } Status ParseInputArrayInfo(absl::string_view array_names, @@ -138,7 +138,7 @@ static Status HandleSubtype(absl::string_view subtype, subtype_tensor_shape.add_dim()->set_size(dim); } *result = {subtype_dtype, subtype_tensor_shape}; - return OkStatus(); + return absl::OkStatus(); } Status ParseInputArrayInfo( @@ -214,7 +214,7 @@ Status ParseInputArrayInfo( } } } - return OkStatus(); + return absl::OkStatus(); } Status ParseNodeShapes( @@ -232,13 +232,13 @@ Status ParseNodeShapes( shapes_vector.push_back(std::move(shape)); } } - return OkStatus(); + return absl::OkStatus(); } Status ParseNodeNames(absl::string_view names_str, std::vector& names_vector) { names_vector = absl::StrSplit(names_str, ',', absl::SkipEmpty()); - return OkStatus(); + return absl::OkStatus(); } static absl::StatusOr> ParseDTypesHelper( @@ -290,7 +290,7 @@ Status ParseNodeDataTypes(absl::string_view data_types_str, if (!data_types_str.empty()) { TF_ASSIGN_OR_RETURN(data_type_vector, ParseDTypesHelper(data_types_str)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc deleted file mode 100644 index f0b415062f2d27..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h" - -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Verifier.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "xla/status_macros.h" -#include "tensorflow/core/common_runtime/graph_constructor.h" -#include "tensorflow/core/framework/graph_debug_info.pb.h" - -namespace tensorflow { - -using mlir::MLIRContext; - -static absl::StatusOr> Import( - const GraphOptimizationPassOptions& options, const Graph& graph, - MLIRContext* context) { - // TODO(fengliuai): get debug info at runtime. - GraphDebugInfo debug_info; - GraphImportConfig specs; - specs.enable_shape_inference = options.shape_inference_on_tfe_dialect_import; - - TF_ASSIGN_OR_RETURN( - auto module, - ConvertGraphToMlir(graph, debug_info, *options.flib_def, specs, context)); - mlir::StatusScopedDiagnosticHandler status_handler(context); - if (failed(mlir::verify(*module))) { - if (VLOG_IS_ON(1)) module->dump(); - return status_handler.ConsumeStatus(); - } - return module; -} - -static Status Export(mlir::OwningOpRef module, - const GraphOptimizationPassOptions& options, - std::unique_ptr* graph) { - GraphExportConfig confs; - return ConvertMlirToGraph(*module, confs, graph, options.flib_def); -} - -static Status Roundtrip(const GraphOptimizationPassOptions& options, - std::unique_ptr* graph, MLIRContext* context) { - TF_ASSIGN_OR_RETURN(auto module, Import(options, **graph, context)); - return Export(std::move(module), options, graph); -} - -Status MlirRoundtripPass::Run(const GraphOptimizationPassOptions& options) { - MLIRContext context; - if (options.graph) return Roundtrip(options, options.graph, &context); - - // If the graph is partitioned, then try and round trip them individually. - for (auto& it : *options.partition_graphs) { - VLOG(1) << "Roundtripping: " << it.first; - // TODO(jpienaar): Roundtrip results in different failures, investigate. - TF_RETURN_IF_ERROR(Import(options, *it.second, &context).status()); - } - return OkStatus(); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h deleted file mode 100644 index 81500cc9b78a76..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { - -// An optimization pass that simply roundtrips the Graph to MLIR and back. -class MlirRoundtripPass : public GraphOptimizationPass { - public: - Status Run(const GraphOptimizationPassOptions& options) override; -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index b0759da88e4ced..6eaa15e37d45e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" @@ -46,7 +47,7 @@ limitations under the License. namespace tensorflow { -static StatusOr> GraphdefToMlirImport( +static absl::StatusOr> GraphdefToMlirImport( llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, @@ -109,7 +110,8 @@ static StatusOr> GraphdefToMlirImport( context); } -StatusOr> GraphdefToMlirTranslateFunction( +absl::StatusOr> +GraphdefToMlirTranslateFunction( llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, @@ -125,7 +127,8 @@ StatusOr> GraphdefToMlirTranslateFunction( return module_or; } -StatusOr> GraphdefToMlirTranslateFunction( +absl::StatusOr> +GraphdefToMlirTranslateFunction( llvm::StringRef input, absl::string_view input_arrays, absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, @@ -147,11 +150,12 @@ StatusOr> GraphdefToMlirTranslateFunction( context); } -StatusOr> SavedModelObjectGraphToMlirImport( - absl::string_view saved_model_dir, - const std::unordered_set& tags, - absl::Span exported_names, mlir::MLIRContext* context, - bool unconditionally_use_set_output_shapes) { +absl::StatusOr> +SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir, + const std::unordered_set& tags, + absl::Span exported_names, + mlir::MLIRContext* context, + bool unconditionally_use_set_output_shapes) { tensorflow::SavedModelV2Bundle bundle; auto load_status = tensorflow::SavedModelV2Bundle::Load( std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle); @@ -174,7 +178,8 @@ StatusOr> SavedModelObjectGraphToMlirImport( return module_or; } -StatusOr> SavedModelSignatureDefsToMlirImport( +absl::StatusOr> +SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, @@ -210,15 +215,15 @@ StatusOr> 
SavedModelSignatureDefsToMlirImport( return module_or; } -StatusOr> +absl::StatusOr> SavedModelSignatureDefsToMlirImportLite( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, MLIRImportOptions options) { MetaGraphDef meta_graph_def; - auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir), - tags, &meta_graph_def); + auto status = + ReadMetaGraphDefFromSavedModel(saved_model_dir, tags, &meta_graph_def); if (!status.ok()) { LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir << "': " << status; @@ -239,7 +244,7 @@ SavedModelSignatureDefsToMlirImportLite( return module_or; } -StatusOr> +absl::StatusOr> GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, @@ -263,7 +268,7 @@ GraphdefToSplattedMlirTranslateFunction( if (auto attr = inst.getAttrOfType(attr_id)) { mlir::Attribute rand_val; mlir::Type element_type = attr.getShapedType().getElementType(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { rand_val = mlir::IntegerAttr::get(element_type, std::rand()); } else if (element_type.isF16() || element_type.isF32() || element_type.isF64()) { @@ -286,7 +291,7 @@ GraphdefToSplattedMlirTranslateFunction( return module_or; } -StatusOr> +absl::StatusOr> GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view input_arrays, absl::string_view input_dtypes, absl::string_view input_shapes, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index 3dd76e2c12e85e..cd86b27e13550c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -53,7 +53,8 @@ struct GraphdefToMlirOptions { // Converts a TensorFlow GraphDef contained in `input` param into a MLIR module. // Creates MLIR entities into the given MLIR `context`. -StatusOr> GraphdefToMlirTranslateFunction( +absl::StatusOr> +GraphdefToMlirTranslateFunction( llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, @@ -66,7 +67,8 @@ ABSL_DEPRECATED( "inputs instead of strings") // Converts a TensorFlow GraphDef contained in `input` param into a MLIR module. // Creates MLIR entities into the given MLIR `context`. -StatusOr> GraphdefToMlirTranslateFunction( +absl::StatusOr> +GraphdefToMlirTranslateFunction( llvm::StringRef input, absl::string_view input_arrays, absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, @@ -74,7 +76,7 @@ StatusOr> GraphdefToMlirTranslateFunction( // Similar to the above function, but replaces all constant tensors // with randomly generated splat values. -StatusOr> +absl::StatusOr> GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, @@ -88,7 +90,7 @@ ABSL_DEPRECATED( "inputs instead of strings") // Similar to the above function, but replaces all constant tensors // with randomly generated splat values.
-StatusOr> +absl::StatusOr> GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view input_arrays, absl::string_view input_dtypes, absl::string_view input_shapes, @@ -98,7 +100,8 @@ GraphdefToSplattedMlirTranslateFunction( // Converts a TensorFlow SavedModel stored in the directory with the given // `saved_model_dir` into a MLIR module. Creates MLIR entities into the // given MLIR `context`. -StatusOr> SavedModelObjectGraphToMlirImport( +absl::StatusOr> +SavedModelObjectGraphToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, @@ -108,7 +111,8 @@ StatusOr> SavedModelObjectGraphToMlirImport( // `saved_model_dir` into a MLIR module. Creates MLIR entities into the // given MLIR `context`. // 'saved_model_bundle' if not null, will be initialized with the model bundle. -StatusOr> SavedModelSignatureDefsToMlirImport( +absl::StatusOr> +SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, @@ -120,7 +124,7 @@ StatusOr> SavedModelSignatureDefsToMlirImport( // `saved_model_dir` into a MLIR module. Creates MLIR entities into the // given MLIR `context`. This does not create session internally so it is faster // and does not perform any graph transformation. -StatusOr> +absl::StatusOr> SavedModelSignatureDefsToMlirImportLite( absl::string_view saved_model_dir, const std::unordered_set& tags, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index eb9bf3db34106d..0357a2e5f22986 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -172,7 +172,7 @@ static LogicalResult MlirToGraphdefTranslateFunction( confs.export_entry_func_to_flib = export_entry_func_to_flib; confs.export_original_tf_func_name = export_original_tf_func_name; - StatusOr> graphdef_or( + absl::StatusOr> graphdef_or( tensorflow::ConvertMlirToGraphdef(module, confs)); if (!graphdef_or.status().ok()) { LOG(ERROR) << "Graph export failed: " << graphdef_or.status(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc deleted file mode 100644 index 856db032e501ae..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/ToolOutputFile.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tsl/platform/protobuf.h" - -namespace mlir { -static mlir::Operation* ExtractOnlyOp(mlir::ModuleOp module) { - mlir::func::FuncOp fn = module.lookupSymbol("main"); - if (!fn) return nullptr; - - if (!llvm::hasSingleElement(fn)) return nullptr; - - // Here, modules with exactly two operations in the only basic block are - // supported. The last operation should be a terminator operation and the - // other operation is the operation of interest. - auto& block = fn.front(); - if (block.getOperations().size() != 2) return nullptr; - if (!block.back().hasTrait()) return nullptr; - - return &block.front(); -} - -static LogicalResult MlirToTfNodeDef(ModuleOp module, - llvm::raw_ostream& output) { - auto* context = module.getContext(); - - Operation* op = ExtractOnlyOp(module); - if (!op) { - emitError(UnknownLoc::get(context), - "modules with exactly one op other than terminator in a " - "'main' function's " - "only block are supported"); - return failure(); - } - - auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef( - op, "node_name", /*ignore_unregistered_attrs=*/false); - if (!node_def_or.ok()) { - op->emitError("failed to convert to TF NodeDef:") - << node_def_or.status().ToString(); - return failure(); - } - - output << tsl::LegacyUnredactedDebugString(*node_def_or.value()); - return success(); -} - -// Test only translation to convert a simple MLIR module with a single TF -// dialect op to NodeDef. -static TranslateFromMLIRRegistration translate_from_mlir_registration( - "test-only-mlir-to-tf-nodedef", "test-only-mlir-to-tf-nodedef", - MlirToTfNodeDef, mlir::RegisterAllTensorFlowDialects); - -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.cc index 098c7d19411979..45235b2931187c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" -#include #include #include #include @@ -34,16 +33,15 @@ using ::tensorflow::kValidDeviceTypes; LogicalResult HasValidCompilationAndReplicationAttributes(Operation& op) { auto replicate_attr = op.getAttrOfType(kReplicationInfoAttr); auto compile_attr = op.getAttrOfType(kCompileDeviceTypeAttr); - if (replicate_attr && !compile_attr) { + if (!replicate_attr && !compile_attr) return success(); + if (!replicate_attr || !compile_attr) + return op.emitOpError() << "is expected to have either both or none of '" + << kReplicationInfoAttr << "' and '" + << kCompileDeviceTypeAttr << "' attributes."; + if (replicate_attr.getValue().empty()) return op.emitOpError() - << "has '" << kReplicationInfoAttr << "' attribute but not '" - << kCompileDeviceTypeAttr << "' attribute which is unsupported"; - } - if (replicate_attr && replicate_attr.getValue().empty()) { - return op.emitOpError() - << "has an empty '" << kReplicationInfoAttr << "' attribute"; - } - if (compile_attr && failed(IsValidDeviceTypeOrEmpty(compile_attr))) { + << "has an empty '" << kReplicationInfoAttr << "' attribute."; + if (failed(IsValidDeviceTypeOrEmpty(compile_attr))) { return op.emitOpError() << "has invalid '" << kCompileDeviceTypeAttr << "' value '" << compile_attr.getValue() << "'"; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h index 5a99806d4295f3..0771b529465a94 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" namespace mlir { @@ -167,7 +168,7 @@ class IdentityNOp; // as an attribute. template bool GetValueAsConstant(Value val, AttrT &attr) { - while (auto result = val.dyn_cast()) { + while (auto result = mlir::dyn_cast(val)) { Operation *op = result.getOwner(); if (!isa(op) && !isa(op)) break; val = op->getOperand(result.getResultNumber()); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc index fd3c00a3873e5c..030b8ae7575a40 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" namespace mlir { @@ -54,7 +55,7 @@ llvm::SmallVector GetEntryFunctions(ModuleOp module) { LogicalResult GetCallees(SymbolUserOpInterface op, SymbolTable &symtab, llvm::SmallVector &callees) { for (auto attr : op->getAttrs()) { - auto sym = attr.getValue().dyn_cast(); + auto sym = mlir::dyn_cast(attr.getValue()); if (!sym) continue; auto callee = symtab.lookup(sym.getRootReference()); if (!callee) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc index 341749eddd0f63..9262f87edb46bd 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc @@ -32,7 +32,7 @@ namespace { constexpr StringRef kTestClusterName = "tpu0"; -tsl::StatusOr> GetMlirModuleFromString( +absl::StatusOr> GetMlirModuleFromString( StringRef string, MLIRContext* context) { DialectRegistry mlir_registry; RegisterAllTensorFlowDialects(mlir_registry); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc index fc0ee8b9d20691..5e320f3ab01a5e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc @@ -28,8 +28,8 @@ namespace tensorflow { // Converts non func AttrValue proto into an MLIR attribute. Func attribute is // exclused in this function because the function might be renamed when the // function definition is imported. -StatusOr ConvertNonFuncAttributeValue(const AttrValue& value, - mlir::Builder* builder) { +absl::StatusOr ConvertNonFuncAttributeValue( + const AttrValue& value, mlir::Builder* builder) { switch (value.value_case()) { case AttrValue::kI: return builder->getI64IntegerAttr(value.i()); @@ -90,8 +90,8 @@ StatusOr ConvertNonFuncAttributeValue(const AttrValue& value, } } -StatusOr ConvertAttributeValue(const AttrValue& value, - mlir::Builder* builder) { +absl::StatusOr ConvertAttributeValue(const AttrValue& value, + mlir::Builder* builder) { switch (value.value_case()) { case AttrValue::kFunc: { // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h index 18f732081de8ee..10271fcbd60f5c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h @@ -27,12 +27,12 @@ using tsl::StatusOr; // Converts non func AttrValue proto into an MLIR attribute. Func attribute is // exclused in this function because the function might be renamed when the // function definition is imported. -StatusOr ConvertNonFuncAttributeValue(const AttrValue& value, - mlir::Builder* builder); +absl::StatusOr ConvertNonFuncAttributeValue( + const AttrValue& value, mlir::Builder* builder); // Converts all kinds of AttrValue proto into an MLIR attribute. 
-StatusOr ConvertAttributeValue(const AttrValue& value, - mlir::Builder* builder); +absl::StatusOr ConvertAttributeValue(const AttrValue& value, + mlir::Builder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 10e882192cfdf3..b9fef486428977 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -242,12 +243,12 @@ void ConvertToTensorShapeProto(ArrayRef shape, } PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) { - if (type.isa()) { + if (mlir::isa(type)) { // An empty PartialTensorShape indicates an unranked tensor. return PartialTensorShape(); } - if (auto tensor_type = type.dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(type)) { TensorShapeProto tensor_shape_proto; ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto); return PartialTensorShape(tensor_shape_proto); @@ -259,11 +260,11 @@ PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) { } mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TF::ShapeAttr::get(type.getContext(), std::nullopt); } - if (auto tensor_type = type.dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(type)) { return mlir::TF::ShapeAttr::get(type.getContext(), tensor_type.getShape()); } @@ -427,10 +428,10 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { output->set_dtype(output_dtype); ConvertToTensorShapeProto(shape, output->mutable_tensor_shape()); - if (auto tensor_attr = attr.dyn_cast()) + if (auto tensor_attr = mlir::dyn_cast(attr)) return ConvertTensorProtoAttr(tensor_attr, output); - auto dense_attr = attr.dyn_cast(); + auto dense_attr = mlir::dyn_cast(attr); if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr"); switch (output_dtype) { @@ -496,7 +497,7 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { output->mutable_tensor_content()); break; case DT_STRING: - ConvertStringElementsAttr(dense_attr.cast(), + ConvertStringElementsAttr(mlir::cast(dense_attr), output->mutable_string_val()); break; case DT_UINT8: @@ -521,7 +522,7 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { return errors::Unimplemented(absl::StrCat("Unimplemented data type ", DataTypeString(output_dtype))); } - return OkStatus(); + return absl::OkStatus(); } Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) { @@ -530,7 +531,7 @@ Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) { if (!output_tensor->FromProto(tensor_proto)) { return InvalidArgument("Couldn't convert tensor proto to tensor."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index 
227e4bf465f70b..92d6ee4bb65356 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -31,12 +31,12 @@ namespace tensorflow { using tsl::StatusOr; // Converts a TensorFlow tensor proto into an MLIR elements attribute. -StatusOr ConvertTensorProto(const TensorProto& input_tensor, - mlir::Builder* builder); +absl::StatusOr ConvertTensorProto( + const TensorProto& input_tensor, mlir::Builder* builder); // Converts a TensorFlow tensor into an MLIR elements attribute. -StatusOr ConvertTensor(const Tensor& input_tensor, - mlir::Builder* builder); +absl::StatusOr ConvertTensor(const Tensor& input_tensor, + mlir::Builder* builder); // Converts a shape from MLIR to a TensorFlow tensor shape proto. void ConvertToTensorShapeProto(llvm::ArrayRef shape, @@ -53,8 +53,8 @@ absl::StatusOr ConvertTypeToTensorSpecProto( const mlir::Type& type); // Converts a TensorFlow shape attribute to an MLIR shape attribute. -StatusOr ConvertTensorShapeProto(const TensorShapeProto& shape, - mlir::MLIRContext* context); +absl::StatusOr ConvertTensorShapeProto( + const TensorShapeProto& shape, mlir::MLIRContext* context); // Converts an MLIR elements attribute to a TensorFlow tensor proto. Status ConvertToTensorProto(mlir::ElementsAttr attr, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index f3c51f88fc7630..3feed8904fab0e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "xla/test.h" @@ -97,8 +98,8 @@ TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) { ASSERT_TRUE(value_or_status.ok()); auto attr = value_or_status.value(); - EXPECT_TRUE(attr.isa()); - auto string_attr = attr.cast(); + EXPECT_TRUE(mlir::isa(attr)); + auto string_attr = mlir::cast(attr); auto string_values = string_attr.getRawStringData(); ASSERT_EQ(string_values.size(), 4); EXPECT_EQ(string_values[0], mlir::StringRef("one")); @@ -191,7 +192,7 @@ TEST_F(ConvertTensorTest, Simple) { } bool IsSplat(mlir::ElementsAttr attr) { - return attr.cast().isSplat(); + return mlir::cast(attr).isSplat(); } TEST(ConvertTensorProtoTest, SplatTensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index 880501c3e89554..e3404d613c9f83 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/core/framework/types.h" @@ -38,61 +39,61 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) { switch (dtype) { case DT_HALF: *type = builder.getF16Type(); - return OkStatus(); + return absl::OkStatus(); case DT_FLOAT: *type = builder.getF32Type(); - return OkStatus(); + return absl::OkStatus(); case DT_DOUBLE: *type = builder.getF64Type(); - return OkStatus(); + return absl::OkStatus(); case DT_BOOL: *type = builder.getIntegerType(1); - return OkStatus(); + return absl::OkStatus(); case DT_INT8: *type = builder.getIntegerType(8); - return OkStatus(); + return absl::OkStatus(); case DT_INT16: *type = builder.getIntegerType(16); - return OkStatus(); + return absl::OkStatus(); case DT_INT32: *type = builder.getIntegerType(32); - return OkStatus(); + return absl::OkStatus(); case DT_INT64: *type = builder.getIntegerType(64); - return OkStatus(); + return absl::OkStatus(); case DT_UINT8: *type = builder.getIntegerType(8, /*isSigned=*/false); - return OkStatus(); + return absl::OkStatus(); case DT_UINT16: *type = builder.getIntegerType(16, /*isSigned=*/false); - return OkStatus(); + return absl::OkStatus(); case DT_UINT32: *type = builder.getIntegerType(32, /*isSigned=*/false); - return OkStatus(); + return absl::OkStatus(); case DT_UINT64: *type = builder.getIntegerType(64, /*isSigned=*/false); - return OkStatus(); + return absl::OkStatus(); case DT_BFLOAT16: *type = builder.getBF16Type(); - return OkStatus(); + return absl::OkStatus(); case DT_COMPLEX64: *type = mlir::ComplexType::get(builder.getF32Type()); - return OkStatus(); + return absl::OkStatus(); case DT_COMPLEX128: *type = mlir::ComplexType::get(builder.getF64Type()); - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_FLOAT8_E4M3FN: *type = builder.getFloat8E4M3FNType(); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); case tensorflow::DT_FLOAT8_E5M2: *type = builder.getFloat8E5M2Type(); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); case DT_INT4: *type = builder.getIntegerType(4, /*isSigned=*/true); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); case DT_UINT4: *type = builder.getIntegerType(4, /*isSigned=*/false); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); #define HANDLE_TF_TYPE(tftype, enumerant, name) \ case DT_##enumerant: \ *type = builder.getType(); \ @@ -108,54 +109,54 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) { Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { if (type.isF16()) { *dtype = DT_HALF; - return OkStatus(); + return absl::OkStatus(); } else if (type.isF32()) { *dtype = DT_FLOAT; - return OkStatus(); + return absl::OkStatus(); } else if (type.isF64()) { *dtype = DT_DOUBLE; - return OkStatus(); + return absl::OkStatus(); } else if (type.isBF16()) { *dtype = DT_BFLOAT16; - return OkStatus(); + return absl::OkStatus(); } else if (type.isFloat8E4M3FN()) { *dtype = DT_FLOAT8_E4M3FN; - return OkStatus(); + return absl::OkStatus(); } else if (type.isFloat8E5M2()) { *dtype = DT_FLOAT8_E5M2; - return OkStatus(); - } else if (auto itype = type.dyn_cast()) { + return absl::OkStatus(); + } else if (auto itype = 
mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: *dtype = DT_BOOL; - return OkStatus(); + return absl::OkStatus(); case 4: *dtype = itype.isUnsigned() ? DT_UINT4 : DT_INT4; - return OkStatus(); + return absl::OkStatus(); case 8: *dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8; - return OkStatus(); + return absl::OkStatus(); case 16: *dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16; - return OkStatus(); + return absl::OkStatus(); case 32: *dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32; - return OkStatus(); + return absl::OkStatus(); case 64: *dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64; - return OkStatus(); + return absl::OkStatus(); default: return errors::Unimplemented( absl::StrCat("Converting ", debugString(type), " to DataType")); } - } else if (auto complex_type = type.dyn_cast()) { + } else if (auto complex_type = mlir::dyn_cast(type)) { auto etype = complex_type.getElementType(); if (etype.isF32()) { *dtype = DT_COMPLEX64; - return OkStatus(); + return absl::OkStatus(); } else if (etype.isF64()) { *dtype = DT_COMPLEX128; - return OkStatus(); + return absl::OkStatus(); } return errors::Unimplemented( absl::StrCat("Converting ", debugString(type), " to DataType")); @@ -174,13 +175,13 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { } Status ConvertToDataType(Type type, DataType* dtype) { - if (auto stype = type.dyn_cast()) { + if (auto stype = mlir::dyn_cast(type)) { TF_RETURN_IF_ERROR( ConvertScalarTypeToDataType(stype.getElementType(), dtype)); } else { TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, dtype)); } - return OkStatus(); + return absl::OkStatus(); } void ConvertToMlirShape(const TensorShape& input_shape, @@ -202,7 +203,7 @@ Status ConvertToMlirShape(const TensorShapeProto& input_shape, shape->push_back(d.size() == kTFDynamicSize ? ShapedType::kDynamic : d.size()); } - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr ConvertToMlirTensorType( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h index 35a3d1fb156f2b..3c21aa260499c1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h @@ -45,9 +45,8 @@ Status ConvertToMlirShape(const TensorShapeProto& input_shape, llvm::SmallVectorImpl* shape); // Given a tensor shape and dtype, get the corresponding MLIR tensor type. -StatusOr ConvertToMlirTensorType(const TensorShapeProto& shape, - DataType dtype, - mlir::Builder* builder); +absl::StatusOr ConvertToMlirTensorType( + const TensorShapeProto& shape, DataType dtype, mlir::Builder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index 51db1be0820761..d9249d472b334c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" @@ -67,7 +68,7 @@ mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, for (const auto& kv : llvm::enumerate(array_attr)) { const int idx = kv.index(); - auto string_attr = kv.value().dyn_cast(); + auto string_attr = mlir::dyn_cast(kv.value()); if (!string_attr) return op->emitOpError(llvm::formatv( "bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx)); @@ -100,7 +101,7 @@ mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, llvm::formatv("bad '{0}' attribute, '{1}', not a valid device", kDevicesAttr, name.strref())); - if (auto gpu_metadata = attr.dyn_cast()) { + if (auto gpu_metadata = mlir::dyn_cast(attr)) { devices->AddGpuDevice(device, gpu_metadata); } else { devices->AddDevice(device); @@ -144,10 +145,11 @@ mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, auto devices_attr = op->getAttr(kDevicesAttr); if (!devices_attr) return mlir::success(); - if (auto array_attr = devices_attr.dyn_cast()) { + if (auto array_attr = mlir::dyn_cast(devices_attr)) { return GetDevicesFromOp(op, array_attr, devices); - } else if (auto dict_attr = devices_attr.dyn_cast()) { + } else if (auto dict_attr = + mlir::dyn_cast(devices_attr)) { return GetDevicesFromOp(op, dict_attr, devices); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index 326dbbb4781602..f089ec111991e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" @@ -87,18 +88,18 @@ TEST(DeviceUtilTest, AddDeviceToOp) { ASSERT_EQ(devices_attr.size(), 3); // CPU device added with an empty metadata. - auto device_meta_0 = devices_attr.get(cpu0).dyn_cast(); + auto device_meta_0 = mlir::dyn_cast(devices_attr.get(cpu0)); ASSERT_NE(device_meta_0, nullptr); // GPU device successfully parsed compute capability from description. auto device_meta_1 = - devices_attr.get(gpu0).dyn_cast(); + mlir::dyn_cast(devices_attr.get(gpu0)); ASSERT_NE(device_meta_1, nullptr); ASSERT_EQ(device_meta_1.getCcMajor(), 7); ASSERT_EQ(device_meta_1.getCcMinor(), 0); // If description is empty GPU devices added with an empty metadata. - auto device_meta_2 = devices_attr.get(gpu1).dyn_cast(); + auto device_meta_2 = mlir::dyn_cast(devices_attr.get(gpu1)); ASSERT_NE(device_meta_2, nullptr); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc index 6a66067920fdcb..f0dd8f1c748a25 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/managed_stack_trace.h" @@ -33,7 +34,7 @@ StatusScopedDiagnosticHandler::StatusScopedDiagnosticHandler( this->shouldShowLocFn = [](Location loc) -> bool { // For a Location to be surfaced in the stack, it must evaluate to true. // For any Location that is a FileLineColLoc: - if (FileLineColLoc fileLoc = loc.dyn_cast()) { + if (FileLineColLoc fileLoc = mlir::dyn_cast(loc)) { return !tensorflow::IsInternalFrameForFilename( fileLoc.getFilename().str()); } else { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index f01a3f0e09d19b..96ba0afd096a16 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -81,22 +82,22 @@ Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name, NodeDef::ExperimentalDebugInfo* debug_info) { mlir::Location unwrapped_inst_loc = GetLocationWithoutOpType(inst_loc); - if (auto call_site = unwrapped_inst_loc.dyn_cast()) { - if (auto name_loc = GetLocationWithoutOpType(call_site.getCallee()) - .dyn_cast()) { + if (auto call_site = mlir::dyn_cast(unwrapped_inst_loc)) { + if (auto name_loc = mlir::dyn_cast( + GetLocationWithoutOpType(call_site.getCallee()))) { llvm::StringRef original_node_name, original_func_name; std::tie(original_node_name, original_func_name) = name_loc.getName().strref().split('@'); // The location points to the current node def. 
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
index f01a3f0e09d19b..96ba0afd096a16 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
@@ -39,6 +39,7 @@ limitations under the License.
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -81,22 +82,22 @@ Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name,
                        NodeDef::ExperimentalDebugInfo* debug_info) {
   mlir::Location unwrapped_inst_loc = GetLocationWithoutOpType(inst_loc);

-  if (auto call_site = unwrapped_inst_loc.dyn_cast<mlir::CallSiteLoc>()) {
-    if (auto name_loc = GetLocationWithoutOpType(call_site.getCallee())
-                            .dyn_cast<mlir::NameLoc>()) {
+  if (auto call_site = mlir::dyn_cast<mlir::CallSiteLoc>(unwrapped_inst_loc)) {
+    if (auto name_loc = mlir::dyn_cast<mlir::NameLoc>(
+            GetLocationWithoutOpType(call_site.getCallee()))) {
       llvm::StringRef original_node_name, original_func_name;
       std::tie(original_node_name, original_func_name) =
           name_loc.getName().strref().split('@');
       // The location points to the current node def.
       if (node_name == original_node_name && original_func_name.empty()) {
-        return OkStatus();
+        return absl::OkStatus();
       }
       debug_info->add_original_node_names(original_node_name.str());
       if (!original_func_name.empty()) {
         debug_info->add_original_func_names(original_func_name.str());
       }
     }
-  } else if (auto fused = unwrapped_inst_loc.dyn_cast<mlir::FusedLoc>()) {
+  } else if (auto fused = mlir::dyn_cast<mlir::FusedLoc>(unwrapped_inst_loc)) {
     auto locations = fused.getLocations();
     if (locations.size() <= 1)
       return errors::InvalidArgument("expected experimental debug info.");
@@ -105,22 +106,22 @@ Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name,
       TF_RETURN_IF_ERROR(ConvertLocation(locations[i], node_name, debug_info));
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) {
   value->set_b(attr.getValue());
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) {
   value->set_i(attr.getInt());
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) {
   value->set_f(attr.getValueAsDouble());
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
@@ -130,27 +131,27 @@ Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
 Status ConvertAttribute(const mlir::TF::PlaceholderAttr& attr,
                         AttrValue* value) {
   value->set_placeholder(attr.getValue().str());
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
   SetTensorShapeProto(attr, value->mutable_shape());
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
   value->mutable_func()->set_name(attr.getValue().str());
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type,
                         AttrValue* value) {
-  TF_RETURN_IF_ERROR(
-      ConvertAttribute(attr.getName().cast<mlir::FlatSymbolRefAttr>(), value));
+  TF_RETURN_IF_ERROR(ConvertAttribute(
+      mlir::cast<mlir::FlatSymbolRefAttr>(attr.getName()), value));
   TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(),
                                        /*attrs_to_ignore=*/{}, remove_ref_type,
                                        value->mutable_func()->mutable_attr()));
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
@@ -158,22 +159,22 @@ Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
   switch (mangling_util::GetMangledKind(attr_value)) {
     case mangling_util::MangledKind::kUnknown: {
       value->set_s(std::string(attr_value));
-      return OkStatus();
+      return absl::OkStatus();
     }
     case mangling_util::MangledKind::kDataType: {
       DataType dtype;
       TF_RETURN_IF_ERROR(mangling_util::DemangleDataType(attr_value, &dtype));
       value->set_type(dtype);
-      return OkStatus();
+      return absl::OkStatus();
     }
     case mangling_util::MangledKind::kTensorShape:
       TF_RETURN_IF_ERROR(
           mangling_util::DemangleShape(attr_value, value->mutable_shape()));
-      return OkStatus();
+      return absl::OkStatus();
     default:
       return errors::Unimplemented("Mangled string couldn't be handled!");
   }
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(mlir::Type type, bool remove_ref_type,
@@ -182,7 +183,7 @@ Status ConvertAttribute(mlir::Type type, bool remove_ref_type,
   TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
   if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
   value->set_type(dtype);
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type,
@@ -192,20 +193,20 @@ Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type,
 Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
   value->clear_value();
-  return OkStatus();
+  return absl::OkStatus();
 }

 Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
                         AttrValue* value) {
   auto* list = value->mutable_list();
   for (mlir::Attribute a : attr.getValue()) {
-    if (auto attr = a.dyn_cast<mlir::BoolAttr>()) {
+    if (auto attr = mlir::dyn_cast<mlir::BoolAttr>(a)) {
       list->add_b(attr.getValue());
-    } else if (auto attr = a.dyn_cast<mlir::IntegerAttr>()) {
+    } else if (auto attr = mlir::dyn_cast<mlir::IntegerAttr>(a)) {
       list->add_i(attr.getInt());
-    } else if (auto attr = a.dyn_cast<mlir::FloatAttr>()) {
+    } else if (auto attr = mlir::dyn_cast<mlir::FloatAttr>(a)) {
       list->add_f(attr.getValueAsDouble());
-    } else if (auto attr = a.dyn_cast<mlir::StringAttr>()) {
+    } else if (auto attr = mlir::dyn_cast<mlir::StringAttr>(a)) {
       AttrValue nested_value;
       TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
       switch (nested_value.value_case()) {
@@ -221,32 +222,32 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
         default:
           return errors::Unimplemented("Unhandled nested attribute!");
       }
-    } else if (auto attr = a.dyn_cast<mlir::ElementsAttr>()) {
+    } else if (auto attr = mlir::dyn_cast<mlir::ElementsAttr>(a)) {
       TensorProto tensor;
       TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
       *list->add_tensor() = tensor;
-    } else if (auto attr = a.dyn_cast<mlir::FlatSymbolRefAttr>()) {
+    } else if (auto attr = mlir::dyn_cast<mlir::FlatSymbolRefAttr>(a)) {
       AttrValue attr_val;
       TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
       *list->add_func() = attr_val.func();
-    } else if (auto attr = a.dyn_cast<mlir::TypeAttr>()) {
+    } else if (auto attr = mlir::dyn_cast<mlir::TypeAttr>(a)) {
       AttrValue attr_val;
       // For type attributes, we only propagate the element type.
       mlir::Type elt_type = attr.getValue();
-      if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) {
+      if (auto shaped_type = mlir::dyn_cast<mlir::ShapedType>(elt_type)) {
         elt_type = shaped_type.getElementType();
       }
       TF_RETURN_IF_ERROR(
           ConvertAttribute(elt_type, remove_ref_type, &attr_val));
       list->add_type(attr_val.type());
-    } else if (auto attr = a.dyn_cast<mlir::TF::ShapeAttr>()) {
+    } else if (auto attr = mlir::dyn_cast<mlir::TF::ShapeAttr>(a)) {
       AttrValue attr_val;
       TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
       *list->add_shape() = attr_val.shape();
-    } else if (auto attr = a.dyn_cast<mlir::ArrayAttr>()) {
+    } else if (auto attr = mlir::dyn_cast<mlir::ArrayAttr>(a)) {
       std::vector<int64_t> vals;
       for (mlir::Attribute a : attr.getValue()) {
-        auto i = a.dyn_cast<mlir::IntegerAttr>();
+        auto i = mlir::dyn_cast<mlir::IntegerAttr>(a);
         if (!i)
           return errors::Unimplemented(
               "Expected 64-bit integer array attributes!");
@@ -263,7 +264,7 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
       return errors::Unimplemented("Unhandled attribute!");
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }

 // Returns true if the executor/control dialect op should map to Ref node in
@@ -274,21 +275,21 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
 static bool IsRefTypeControlOp(mlir::Operation* op) {
   if (auto next_iter_sink =
           llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(op))
-    return mlir::getElementTypeOrSelf(next_iter_sink.getInput().getType())
-        .isa<mlir::TF::TensorFlowRefType>();
+    return mlir::isa<mlir::TF::TensorFlowRefType>(
+        mlir::getElementTypeOrSelf(next_iter_sink.getInput().getType()));

   auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef());
   if (!op_name_or_status.ok()) return false;

   auto op_name = std::move(op_name_or_status).value();
-  if (op_name.equals("NextIteration"))
-    return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
-        .isa<mlir::TF::TensorFlowRefType>();
-
-  if (op_name.equals("Enter") || op_name.equals("Exit") ||
-      op_name.equals("Switch") || op_name.equals("Merge")) {
-    return getElementTypeOrSelf(op->getResult(0).getType())
-        .isa<mlir::TF::TensorFlowRefType>();
+  if (op_name == "NextIteration")
+    return mlir::isa<mlir::TF::TensorFlowRefType>(
+        mlir::getElementTypeOrSelf(op->getOperand(0).getType()));
+
+  if (op_name == "Enter" || op_name == "Exit" || op_name == "Switch" ||
+      op_name == "Merge") {
+    return mlir::isa<mlir::TF::TensorFlowRefType>(
+        getElementTypeOrSelf(op->getResult(0).getType()));
   }
   return false;
 }
@@ -393,18 +394,18 @@ Status ConvertAttributes(
       name = mangling_util::DemangleAttributeName(name);
     }
     AttrValue value;
-    if (auto symbol_ref = attr.dyn_cast<mlir::SymbolRefAttr>()) {
-      TF_RETURN_IF_ERROR(
-          ConvertAttribute(symbol_ref.cast<mlir::FlatSymbolRefAttr>(), &value));
+    if (auto symbol_ref = mlir::dyn_cast<mlir::SymbolRefAttr>(attr)) {
+      TF_RETURN_IF_ERROR(ConvertAttribute(
+          mlir::cast<mlir::FlatSymbolRefAttr>(symbol_ref), &value));
       func_call_attrs[string(name)] = std::move(value);
       continue;
     }
-    if (auto func_attr = attr.dyn_cast<mlir::TF::FuncAttr>()) {
+    if (auto func_attr = mlir::dyn_cast<mlir::TF::FuncAttr>(attr)) {
       TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value));
       func_call_attrs[string(name)] = std::move(value);
       continue;
    }
-    if (attr.isa<mlir::AffineMapAttr>()) {
+    if (mlir::isa<mlir::AffineMapAttr>(attr)) {
      // AffineMapAttr is not implemented.
return errors::Unimplemented("AffineMap attribute (needed for '", name_strref, "') unimplemented"); @@ -444,7 +445,7 @@ Status ConvertAttributes( for (auto& it : func_call_attrs) { (*values)[it.first] = std::move(it.second); } - return OkStatus(); + return absl::OkStatus(); } Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type, @@ -467,7 +468,7 @@ Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type, actual_shape.ShortDebugString()); } } - return OkStatus(); + return absl::OkStatus(); } bool IsLegacyCallInstruction(mlir::Operation* inst) { @@ -476,7 +477,7 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) { Status AddTensorFlowOpPrefix(std::string prefix) { GlobalOpPrefixes()->insert(prefix); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 86ff64b5ed4d0b..c12c2507e1a03c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -46,11 +46,11 @@ Status AddTensorFlowOpPrefix(std::string); // Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control // dialect back into a TensorFlow valid op name. -StatusOr GetTensorFlowOpName(llvm::StringRef); +absl::StatusOr GetTensorFlowOpName(llvm::StringRef); // Converts an MLIR operation to TensorFlow NodeDef with given node name. This // name should be unique to the graph it is being inserted into. -StatusOr> GetOperationNodeDef( +absl::StatusOr> GetOperationNodeDef( mlir::Operation* inst, llvm::StringRef name); // Converts MLIR attributes with values to their tensorflow equivalent. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/location_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/location_utils.cc index 2a6ff2921a4ad5..afaa78640af3d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/location_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/location_utils.cc @@ -17,15 +17,16 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace tensorflow { mlir::Location GetLocationWithoutOpType(mlir::Location loc) { - if (auto fused_loc = loc.dyn_cast()) { + if (auto fused_loc = mlir::dyn_cast(loc)) { auto locations = fused_loc.getLocations(); if (!locations.empty()) { // Skip locations for propagating op_type metadata. - if (auto name_loc = locations[0].dyn_cast()) { + if (auto name_loc = mlir::dyn_cast(locations[0])) { if (name_loc.getName().strref().ends_with(":")) { if (locations.size() == 2) return locations[1]; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc index 2895ebdc9c6424..9e8db314f51b0d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/device.h" @@ -32,7 +33,7 @@ std::string GetVariableName(TF::VarHandleOp var_handle_op) { // In some cases the shared_name attribute doesn't have the same // tensor name in the model, so we first try to use the location // then fallback to shared_name attribute. - if (auto loc = var_handle_op->getLoc().dyn_cast()) + if (auto loc = mlir::dyn_cast(var_handle_op->getLoc())) return loc.getName().str(); return var_handle_op.getSharedName().str(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc index 549b665f044314..6ab4aa64a89070 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { @@ -63,7 +64,7 @@ FailureOr GetTfFuncCustomCallFuncName( return failure(); } - if (auto attr = f.dyn_cast()) { + if (auto attr = mlir::dyn_cast(f)) { return attr; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index 97f1093fe3d56b..5a29bae67afe01 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo @@ -396,11 +397,11 @@ SerializedMlirStringAttrToMlirModuleTranslate(llvm::StringRef input, // an output parameter is provided for returning the number of chars read. size_t numRead; mlir::Attribute attr = mlir::parseAttribute(input, context, {}, &numRead); - if (!attr || !attr.isa()) { + if (!attr || !mlir::isa(attr)) { LOG(ERROR) << "Input is not parsable as a MLIR StringAttr."; return nullptr; } - auto str_attr = attr.cast(); + auto str_attr = mlir::cast(attr); mlir::DialectRegistry registry; RegisterMlirInputDialects(registry); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index c6ff5f5c93c6ef..d2f10367d0085b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -92,7 +93,7 @@ absl::Status MismatchedTPUSystemAttributeErr(absl::string_view attribute, T a, // found, the first one lexicographically is returned. If no TPU_SYSTEM device // is found or if there are multiple TPU_SYSTEM devices with different jobs or // replicas, a failure will be returned. -StatusOr> GetTPUSystemDevices( +absl::StatusOr> GetTPUSystemDevices( ParsedDevices devices) { ParsedDevice spec; spec.type = kDeviceTPUSystem; @@ -131,7 +132,7 @@ StatusOr> GetTPUSystemDevices( // Find TPU devices associated to system device based on spec (e.g. from // GetTPUSystemDevices). If the number of TPU devices per host do not match for // every host, a failure will be returned. -StatusOr, 8>> +absl::StatusOr, 8>> GetTPUDevices(ParsedDevices devices, llvm::ArrayRef system_devices) { llvm::SmallVector, 8> tpu_devices; @@ -192,8 +193,8 @@ std::string GetTPUCompilationDevice(ParsedDevice system_device) { // Find the host CPU device for a given TPU device with `DEVICE_CPU` as its // type. If multiple local cpu devices are disabled, always assign id 0. If // set, use the same id as the tpu device. -StatusOr GetCPUHostDeviceForTPUDevice(ParsedDevice tpu_device, - ParsedDevices devices) { +absl::StatusOr GetCPUHostDeviceForTPUDevice( + ParsedDevice tpu_device, ParsedDevices devices) { tpu_device.type = DEVICE_CPU; bool enable_multiple_local_cpu_devices = tensorflow::GetMlirCommonFlags() @@ -214,7 +215,7 @@ StatusOr GetCPUHostDeviceForTPUDevice(ParsedDevice tpu_device, // to every core in the mesh. TPU devices are simply added to // `execution_devices` of one replica. `num_replicas` must be 1 or the total // number of TPU devices available, and `num_cores_per_replica` must be 1. 
-StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
+absl::StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
     int num_replicas, int num_cores_per_replica,
     llvm::ArrayRef<llvm::SmallVector<ParsedDevice, 8>> tpu_devices,
     ParsedDevices devices) {
@@ -293,7 +294,7 @@ absl::Status DuplicateCoordinateErrorMsg(absl::string_view attribute, int x,
 // - device coordinates within the mesh shape
 // - no duplicate device coordinates
 // - number of device coordinates (in tuple 3) match number of available TPUs
-StatusOr<xla::Array4D<TaskAndDevice>> ParseTopologyAttr(
+absl::StatusOr<xla::Array4D<TaskAndDevice>> ParseTopologyAttr(
     llvm::StringRef topology_attr, int num_tasks, int num_tpus_per_task) {
   tpu::TopologyProto topology_proto;
   if (!topology_proto.ParseFromString(topology_attr.str()))
@@ -375,7 +376,7 @@ StatusOr<xla::Array4D<TaskAndDevice>> ParseTopologyAttr(
 // - number of device coordinates (in tuple 3) match number 'num_replicas' *
 //   'num_cores_per_replica'
 // - a TPU device associated with each device coordinate
-StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
+absl::StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
 GetGeneralTPUExecutionDeviceAssignment(
     int num_replicas, int num_cores_per_replica,
     llvm::ArrayRef<llvm::SmallVector<ParsedDevice, 8>> tpu_devices,
@@ -480,7 +481,7 @@ mlir::LogicalResult GetDeviceAssignmentCoordinates(
     return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
                                              tensorflow::kDeviceAssignmentAttr)
                                    .str());
-  if (StatusOr<llvm::SmallVector<int64_t, 8>> fetched_device_coordinates =
+  if (absl::StatusOr<llvm::SmallVector<int64_t, 8>> fetched_device_coordinates =
           tensorflow::GetDeviceCoordinates(device_assignment_attr);
       fetched_device_coordinates.ok()) {
     device_coordinates = *fetched_device_coordinates;
@@ -516,7 +517,7 @@ mlir::LogicalResult GetTPUDevicesAndHostsNotReplicated(
   }

   // Determine compilation and execution devices.
-  if (StatusOr<TPUDeviceAssignment> tpu_device_assignment =
+  if (absl::StatusOr<TPUDeviceAssignment> tpu_device_assignment =
           tensorflow::GetTPUCompilationAndExecutionDevices(
               devices.device_names(), /*num_replicas=*/1,
               GetNumCoresPerReplica(cluster), topology, device_coordinates);
@@ -601,7 +602,7 @@ mlir::LogicalResult GetTPUToHostMap(

 }  // anonymous namespace

-StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
+absl::StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
     mlir::ArrayAttr device_assignment_attr) {
   llvm::SmallVector<int64_t, 8> device_coordinates;
   device_coordinates.reserve(device_assignment_attr.size());
@@ -609,7 +610,7 @@ StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
   for (auto device_coordinate_and_idx :
        llvm::enumerate(device_assignment_attr)) {
     auto device_coordinate =
-        device_coordinate_and_idx.value().dyn_cast<mlir::IntegerAttr>();
+        mlir::dyn_cast<mlir::IntegerAttr>(device_coordinate_and_idx.value());
     if (!device_coordinate)
       return absl::InvalidArgumentError(
           llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr,
@@ -622,7 +623,7 @@ StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
   return device_coordinates;
 }

-StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
+absl::StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
     ParsedDevices devices, int num_replicas, int num_cores_per_replica,
     llvm::StringRef topology_attr,
     llvm::ArrayRef<int64_t> device_assignment_attr) {
@@ -733,8 +734,8 @@ bool IsTPUReplicatedCore(llvm::StringRef device) {
 bool TypeValidForXLA(const mlir::Type& type) {
   const mlir::Type elem = getElementTypeOrSelf(type);
-  return !elem.isa<mlir::TF::ResourceType>() &&
-         !elem.isa<mlir::TF::StringType>();
+  return !mlir::isa<mlir::TF::ResourceType>(elem) &&
+         !mlir::isa<mlir::TF::StringType>(elem);
 }

 mlir::LogicalResult GetDeviceToHostMap(
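A hedged sketch of consuming the `absl::StatusOr` results above, mirroring the `GetDeviceAssignmentCoordinates` pattern (a C++17 if-initializer keeps the `StatusOr` scoped to the check); the helper name `ReadCoordinates` and its error text are hypothetical:

```c++
#include "absl/status/statusor.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

mlir::LogicalResult ReadCoordinates(mlir::Operation* op,
                                    mlir::ArrayAttr device_assignment_attr,
                                    llvm::SmallVector<int64_t, 8>& out) {
  if (absl::StatusOr<llvm::SmallVector<int64_t, 8>> coords =
          tensorflow::GetDeviceCoordinates(device_assignment_attr);
      coords.ok()) {
    out = *coords;  // Move the parsed coordinates out of the StatusOr.
    return mlir::success();
  }
  return op->emitError("bad device assignment attribute");
}
```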
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h
index 9ed5d7614aaf4c..f7c9b29d6cfdcc 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h
@@ -80,7 +80,7 @@ struct TPUDeviceAssignment {
 };

 // Extracts device coordinates from a device assignment attribute on an op.
-StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
+absl::StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
     mlir::ArrayAttr device_assignment_attr);

 // Finds the TPU compilation device and execution devices from `devices` for a
@@ -234,7 +234,7 @@ StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
 //     replica_device_ids: 7
 //   }
 // }
-StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
+absl::StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
     llvm::ArrayRef<DeviceNameUtils::ParsedName> devices, int num_replicas,
     int num_cores_per_replica, llvm::StringRef topology_attr,
     llvm::ArrayRef<int64_t> device_assignment_attr);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
index 2c749b549cdc86..c6d80802b2aa0a 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
@@ -38,7 +38,7 @@ limitations under the License.
 namespace tensorflow {
 namespace {

-tsl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GetMlirModuleFromString(
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GetMlirModuleFromString(
     llvm::StringRef string, mlir::MLIRContext* context) {
   mlir::DialectRegistry mlir_registry;
   RegisterAllTensorFlowDialects(mlir_registry);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc
index 988950389edf8b..e23e2313711f9c 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"

 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/core/platform/errors.h"

 namespace tensorflow {
@@ -44,21 +45,21 @@ mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module,
   if (!version_attr) return mlir::failure();

   auto producer =
-      version_attr.get("producer").dyn_cast_or_null<mlir::IntegerAttr>();
+      mlir::dyn_cast_or_null<mlir::IntegerAttr>(version_attr.get("producer"));
   if (!producer) return mlir::failure();
   versions->set_producer(producer.getInt());

-  auto min_consumer =
-      version_attr.get("min_consumer").dyn_cast_or_null<mlir::IntegerAttr>();
+  auto min_consumer = mlir::dyn_cast_or_null<mlir::IntegerAttr>(
+      version_attr.get("min_consumer"));
   if (min_consumer) versions->set_min_consumer(min_consumer.getInt());

-  auto bad_consumers =
-      version_attr.get("bad_consumers").dyn_cast_or_null<mlir::ArrayAttr>();
+  auto bad_consumers = mlir::dyn_cast_or_null<mlir::ArrayAttr>(
+      version_attr.get("bad_consumers"));
   if (!bad_consumers) return mlir::success();

   for (auto bad_consumer : bad_consumers) {
     auto bad_consumer_int_attr =
-        bad_consumer.dyn_cast_or_null<mlir::IntegerAttr>();
+        mlir::dyn_cast_or_null<mlir::IntegerAttr>(bad_consumer);
     if (!bad_consumer_int_attr) return mlir::failure();

     versions->mutable_bad_consumers()->Add(bad_consumer_int_attr.getInt());
@@ -66,13 +67,13 @@ mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module,
   return mlir::success();
 }

-::tsl::StatusOr<int64_t> GetTfGraphProducerVersion(mlir::ModuleOp module) {
+absl::StatusOr<int64_t> GetTfGraphProducerVersion(mlir::ModuleOp module) {
   auto versions = module->getAttrOfType<::mlir::DictionaryAttr>("tf.versions");
   if (!versions) {
     return errors::Internal(
         "Missing 'tf.versions' attribute on the module, abort.\n");
   }
-  auto producer = versions.get("producer").dyn_cast<mlir::IntegerAttr>();
+  auto producer = mlir::dyn_cast<mlir::IntegerAttr>(versions.get("producer"));
   if (!producer) {
     return errors::Internal(
         "Missing 'producer' attribute on the module, abort.\n");
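The translate_utils change keeps the same contract: the producer version is read from the module-level `tf.versions` dictionary attribute, e.g. `module attributes {tf.versions = {producer = 268 : i32, ...}}`. A hedged caller sketch, assuming the `absl::StatusOr<int64_t>` return type reconstructed above (`CheckProducer` is hypothetical):

```c++
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mlir/IR/BuiltinOps.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/core/platform/logging.h"

absl::Status CheckProducer(mlir::ModuleOp module) {
  absl::StatusOr<int64_t> producer =
      tensorflow::GetTfGraphProducerVersion(module);
  if (!producer.ok()) return producer.status();  // No/invalid tf.versions.
  LOG(INFO) << "GraphDef producer version: " << *producer;
  return absl::OkStatus();
}
```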
a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h
index feccb1754d5781..f9acbb9a88e7cb 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h
@@ -37,7 +37,7 @@ mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module,

 // Returns TensorFlow GraphDef producer version for the given module. Returns an
 // error if the version information is missing for the module or is not valid.
-::tsl::StatusOr<int64_t> GetTfGraphProducerVersion(mlir::ModuleOp module);
+absl::StatusOr<int64_t> GetTfGraphProducerVersion(mlir::ModuleOp module);

 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util_test.cc
index eb4c9ee85274ea..c0046f83664223 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util_test.cc
@@ -36,7 +36,7 @@ limitations under the License.
 namespace tensorflow {
 namespace {

-tsl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GetMlirModuleFromString(
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GetMlirModuleFromString(
     llvm::StringRef string, mlir::MLIRContext* context) {
   mlir::DialectRegistry mlir_registry;
   RegisterAllTensorFlowDialects(mlir_registry);
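The xla_sharding_util hunks that follow all funnel through `DecodeShardingAttribute`, which rejects non-string attributes and then parses the serialized `xla::OpSharding` proto out of the string payload. A hedged wrapper sketch (the name `GetSharding` is hypothetical; the underlying call matches the signature visible in the diff below):

```c++
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "xla/xla_data.pb.h"

mlir::LogicalResult GetSharding(mlir::Attribute shard_attr,
                                xla::OpSharding& sharding) {
  // Non-string attributes are rejected up front, exactly as in the patch.
  if (!mlir::isa<mlir::StringAttr>(shard_attr)) return mlir::failure();
  return tensorflow::DecodeShardingAttribute(shard_attr, sharding,
                                             /*report_error=*/false);
}
```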
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
index ea76adb284b7e2..334cca591cf569 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
@@ -70,7 +70,7 @@ mlir::LogicalResult CreateSplitOp(const int num_split,
   // Correctly set output shapes of split op output if input shape is statically
   // known.
   mlir::Type output_type;
-  auto input_type = src_input.getType().cast<mlir::TensorType>();
+  auto input_type = mlir::cast<mlir::TensorType>(src_input.getType());

   if (input_type.hasRank()) {
     if (input_type.getShape()[split_dimension] == mlir::ShapedType::kDynamic) {
@@ -122,7 +122,7 @@ mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension,
   // across logical devices, we refer to the shape of 0th logical device
   // computation output.
   mlir::Type output_type;
-  auto input_type = inputs[0].getType().cast<mlir::TensorType>();
+  auto input_type = mlir::cast<mlir::TensorType>(inputs[0].getType());

   if (input_type.hasRank()) {
     if (input_type.getShape()[concat_dimension] == mlir::ShapedType::kDynamic) {
@@ -294,9 +294,9 @@ mlir::LogicalResult DecodeShardingAttribute(const std::string& shard_str,
 mlir::LogicalResult DecodeShardingAttribute(mlir::Attribute shard_attr,
                                             xla::OpSharding& sharding,
                                             bool report_error) {
-  if (!shard_attr.isa<mlir::StringAttr>()) return mlir::failure();
+  if (!mlir::isa<mlir::StringAttr>(shard_attr)) return mlir::failure();

-  auto shard_str = shard_attr.cast<mlir::StringAttr>().getValue().str();
+  auto shard_str = mlir::cast<mlir::StringAttr>(shard_attr).getValue().str();
   return DecodeShardingAttribute(shard_str, sharding, report_error);
 }

@@ -350,7 +350,8 @@ mlir::LogicalResult ExtractInputsForLogicalDevices(
     xla::OpSharding sharding;
     if (DecodeShardingAttribute(
-            sharding_attr.cast<mlir::StringAttr>().getValue().str(), sharding)
+            mlir::cast<mlir::StringAttr>(sharding_attr).getValue().str(),
+            sharding)
             .failed()) {
       return cluster_func.emitError("incorrect sharding format for inputs");
     }
@@ -443,13 +444,14 @@ mlir::LogicalResult ParseAndValidateOutputSharding(
        llvm::enumerate(output_sharding_attrs)) {
     const auto& output_sharding = output_sharding_and_index.value();
     const int sharding_index = output_sharding_and_index.index();
-    if (!output_sharding.isa<mlir::StringAttr>())
+    if (!mlir::isa<mlir::StringAttr>(output_sharding))
       return cluster_func.emitError(llvm::formatv(
           "non-string output sharding at index {0}", sharding_index));

     xla::OpSharding sharding;
     if (DecodeShardingAttribute(
-            output_sharding.cast<mlir::StringAttr>().getValue().str(), sharding)
+            mlir::cast<mlir::StringAttr>(output_sharding).getValue().str(),
+            sharding)
             .failed()) {
       return cluster_func.emitError("incorrect sharding format for outputs");
     }
@@ -661,7 +663,7 @@ mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation(
     const auto output_index = result_and_index.index();
     const auto& output_sharding = output_sharding_config[output_index];
     const auto cluster_func_output_type =
-        result_and_index.value().getType().cast<mlir::TensorType>();
+        mlir::cast<mlir::TensorType>(result_and_index.value().getType());

     // If output shape of cluster func is statically known and output is tiled
     // sharded, then the corresponding output shape of cluster func must be
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc
index 322862828e63b3..7358b97971e0fe 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc
@@ -175,13 +175,13 @@ Status GetXlaInputShapes(
 // bounded type by using the bounds as dimension sizes. Returns null if it is
 // neither.
 mlir::RankedTensorType GetBufferType(mlir::Type ty) {
-  auto ranked_ty = ty.dyn_cast_or_null<mlir::RankedTensorType>();
+  auto ranked_ty = mlir::dyn_cast_or_null<mlir::RankedTensorType>(ty);
   if (!ranked_ty) return {};

   int64_t rank = ranked_ty.getRank();
   llvm::SmallVector<int64_t, 4> dims = llvm::to_vector<4>(ranked_ty.getShape());
-  auto encoding = ranked_ty.getEncoding()
-                      .dyn_cast_or_null<mlir::mhlo::TypeExtensionsAttr>();
+  auto encoding = mlir::dyn_cast_or_null<mlir::mhlo::TypeExtensionsAttr>(
+      ranked_ty.getEncoding());
   if (encoding && !encoding.getBounds().empty()) {
     for (int64_t dim = 0; dim < rank; ++dim) {
       if (dims[dim] == mlir::ShapedType::kDynamic) {
@@ -234,7 +234,7 @@ Status GetOutputInfo(
   auto return_op = main_func.begin()->getTerminator();
   for (const auto& type_and_idx : llvm::enumerate(func_type.getResults())) {
     size_t idx = type_and_idx.index();
-    auto result_ty = type_and_idx.value().cast<mlir::RankedTensorType>();
+    auto result_ty = mlir::cast<mlir::RankedTensorType>(type_and_idx.value());

     // If the result type isn't static, then the owner of the result may be a
     // cast op from a more specific bounded type to an unbounded dynamic type.
@@ -275,7 +275,8 @@ Status GetOutputInfo(
     TF_RETURN_IF_ERROR(MaybeRewriteLayoutWithShardedShape(
         sharding, shape_determination_fns, &shape));

-    auto tensor_type = type_and_idx.value().dyn_cast<mlir::RankedTensorType>();
+    auto tensor_type =
+        mlir::dyn_cast<mlir::RankedTensorType>(type_and_idx.value());
     shapes.push_back(shape);

     auto it = output_to_input_alias.find(type_and_idx.index());
@@ -872,7 +873,7 @@ static absl::StatusOr<std::vector<bool>> RewriteWithArgs(
       auto resource_type =
           mlir::TF::ResourceType::get({resource_subtype}, builder.getContext());

-      auto tensor_type = mlir_arg.getType().cast<mlir::TensorType>();
+      auto tensor_type = mlir::cast<mlir::TensorType>(mlir_arg.getType());
       if (tensor_type.hasRank()) {
         mlir_arg.setType(
             GetTypeFromTFTensorShape(tensor_type.getShape(), resource_type));
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
index 545203ad20ea23..709a63bea84ebe 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
@@ -210,6 +210,7 @@ tf_cc_test(
     srcs = ["tf_dialect_to_executor_test.cc"],
     data = [
         "testdata/empty_func.mlir",
+        "testdata/func_with_dead_ops.mlir",
         "testdata/invalid_executor.mlir",
     ],
     deps = [
@@ -220,10 +221,9 @@ tf_cc_test(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
+        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@local_tsl//tsl/lib/core:status_test_util",
-        "@local_tsl//tsl/lib/monitoring:test_utils",
-        "@local_tsl//tsl/platform:status",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc
index 0e7e61999d8f2b..0f5680bda420d2 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc
@@ -46,6 +46,7 @@ namespace tf2xla {
 namespace v2 {

 using ::tensorflow::monitoring::testing::CellReader;
+using ::testing::Not;
 using ::testing::TestWithParam;
 using tpu::FunctionToHloArgs;
 using tpu::MlirToHloArgs;
@@ -334,6 +335,41 @@ TEST(LegalizeTFTest, SuccessfullyCompilesModulesWithReturnValues) {
               ComputationProtoContains("opcode:.*constant"));
 }

+TEST(LegalizeTFTest, SkipsTensorListSetItemIfDimensionsTooLarge) {
+  static constexpr char kTensorListSetItemDimensionTooLarge[] = R"(
+  module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+    func.func @main() -> tensor>> {
+      // unknown rank
+      %elem_shape = "tf.Const"() <{value = dense<-1> : tensor}> {device =
"/job:localhost/replica:0/task:0/device:CPU:0"} : () -> tensor + // zero reserved elements + %num_elements = "tf.Const"() <{value = dense<0> : tensor}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> tensor + + %list = "tf.TensorListReserve"(%elem_shape, %num_elements) : (tensor, tensor) -> tensor>> + + %index = "tf.Const"() <{value = dense<0> : tensor}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> tensor + %element = "tf.Const"() <{value = dense<0.0> : tensor<64x1xbf16>}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> tensor<64x1xbf16> + // Results in a bad mismatch of shapes. + %updated_list = "tf.TensorListSetItem"(%list, %index, %element) : (tensor>>, tensor, tensor<64x1xbf16>) -> tensor>> + + return %updated_list : tensor>> + } + })"; + + auto compilation_result = CompileMlirModule( + kTensorListSetItemDimensionTooLarge, + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED); + + // Ensure that it compile + ASSERT_TRUE(compilation_result.ok()); + // Assert that the tensor list operation is lowered to something. + ASSERT_THAT(compilation_result, + Not(ComputationProtoContains("%.*= \"tf.TensorListSetItem"))); + // Assert that the tensor list operation is lowered to something that doesn't + // get stuck on a broken dynamic update slice. + ASSERT_THAT(compilation_result, + Not(ComputationProtoContains("%.*=.*DynamicUpdateSlice"))); +} + } // namespace v2 } // namespace tf2xla } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/func_with_dead_ops.mlir b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/func_with_dead_ops.mlir new file mode 100644 index 00000000000000..f8dd51f4e12d3c --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/func_with_dead_ops.mlir @@ -0,0 +1,62 @@ +module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1847 : i32}} { + func.func @main(%arg0: tensor {tf._user_specified_name = "steps", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>> {tf._user_specified_name = "899", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource>> {tf._user_specified_name = "901", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg3: tensor<*x!tf_type.resource>> {tf._user_specified_name = "903", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource>> {tf._user_specified_name = "905", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg5: tensor<*x!tf_type.resource>> 
{tf._user_specified_name = "907", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource>> {tf._user_specified_name = "909", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg7: tensor<*x!tf_type.resource>> {tf._user_specified_name = "911", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg8: tensor<*x!tf_type.resource>> {tf._user_specified_name = "913", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg9: tensor<*x!tf_type.resource>> {tf._user_specified_name = "915", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg10: tensor<*x!tf_type.resource>> {tf._user_specified_name = "917", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg11: tensor<*x!tf_type.resource>> {tf._user_specified_name = "919", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg12: tensor<*x!tf_type.resource>> {tf._user_specified_name = "921", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg13: tensor<*x!tf_type.resource>> {tf._user_specified_name = "923", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg14: tensor<*x!tf_type.resource>> {tf._user_specified_name = "925", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg15: tensor<*x!tf_type.resource>> {tf._user_specified_name = "927", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg16: tensor<*x!tf_type.resource>> {tf._user_specified_name = "929", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg17: tensor<*x!tf_type.resource>> {tf._user_specified_name = "931", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg18: tensor<*x!tf_type.resource>> {tf._user_specified_name = "933", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg19: tensor<*x!tf_type.resource>> {tf._user_specified_name = "935", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg20: tensor<*x!tf_type.resource>> {tf._user_specified_name = "937", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg21: tensor<*x!tf_type.resource>> {tf._user_specified_name = "939", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> tensor attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "steps,unknown,unknown_0,unknown_1,unknown_2,unknown_3,unknown_4,unknown_5,unknown_6,unknown_7,unknown_8,unknown_9,unknown_10,unknown_11,unknown_12,unknown_13,unknown_14,unknown_15,unknown_16,unknown_17,unknown_18,unknown_19", outputs = "statefulpartitionedcall_RetVal"}} { + %0 = "tf.ReadVariableOp"(%arg19) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor + %2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf_type.resource>>) -> tensor + %3 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf_type.resource>>) -> tensor<1024xf32> + %4 = "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %5 = "tf.ReadVariableOp"(%arg5) : (tensor<*x!tf_type.resource>>) -> tensor<1024x1xf32> + %6 = "tf.ReadVariableOp"(%arg20) : (tensor<*x!tf_type.resource>>) -> tensor<1024xf32> + %7 = "tf.ReadVariableOp"(%arg21) : (tensor<*x!tf_type.resource>>) -> tensor<1024x1xf32> + %8 = "tf.ReadVariableOp"(%arg6) : (tensor<*x!tf_type.resource>>) -> tensor + %9 = "tf.Const"() <{value = dense<"test"> : tensor<3x!tf_type.string>}> : () -> 
tensor<3x!tf_type.string> + %cst = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %11:4 = "tf.Split"(%cst, %0) {num_split = 4 : i32} : (tensor, tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) + %cst_0 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %12:4 = "tf.Split"(%cst_0, %4) {num_split = 4 : i32} : (tensor, tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) + %cst_1 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %cst_2 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %13:20 = tf_device.replicate {devices = {TPU_REPLICATED_CORE_0 = ["/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0"], TPU_REPLICATED_CORE_1 = ["/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1"], TPU_REPLICATED_CORE_2 = ["/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0"], TPU_REPLICATED_CORE_3 = ["/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1"], TPU_REPLICATED_HOST_0 = ["/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"], TPU_REPLICATED_HOST_1 = ["/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"], TPU_REPLICATED_HOST_2 = ["/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"], TPU_REPLICATED_HOST_3 = ["/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"]}, n = 2 : i32} { + %16:40 = "tf_device.parallel_execute"() ({ + %19:10 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}> ({ + %20:10 = "tf.TPUExecute"(%arg0, %11#0, %1, %2, %3, %12#0, %5, %6, %7, %8, %9) : (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor, tensor<3x!tf_type.string>) -> (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor) + tf_device.return %20#0, %20#1, %20#2, %20#3, %20#4, %20#5, %20#6, %20#7, %20#8, %20#9 : tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor + }) : () -> (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor) + tf_device.return %19#0, %19#1, %19#2, %19#3, %19#4, %19#5, %19#6, %19#7, %19#8, %19#9 : tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor + }, { + %19:10 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}> ({ + %20:10 = "tf.TPUExecute"(%arg0, %11#1, %1, %2, %3, %12#1, %5, %6, %7, %8, %9) : (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor, tensor<3x!tf_type.string>) -> (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor) + tf_device.return %20#0, %20#1, %20#2, %20#3, %20#4, %20#5, %20#6, %20#7, %20#8, %20#9 : tensor, 
tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor + }) : () -> (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor) + tf_device.return %19#0, %19#1, %19#2, %19#3, %19#4, %19#5, %19#6, %19#7, %19#8, %19#9 : tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor + }, { + %19:10 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_2"}> ({ + %20:10 = "tf.TPUExecute"(%arg0, %11#2, %1, %2, %3, %12#2, %5, %6, %7, %8, %9) : (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor, tensor<3x!tf_type.string>) -> (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor) + tf_device.return %20#0, %20#1, %20#2, %20#3, %20#4, %20#5, %20#6, %20#7, %20#8, %20#9 : tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor + }) : () -> (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor) + tf_device.return %19#0, %19#1, %19#2, %19#3, %19#4, %19#5, %19#6, %19#7, %19#8, %19#9 : tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor + }, { + %19:10 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_3"}> ({ + %20:10 = "tf.TPUExecute"(%arg0, %11#3, %1, %2, %3, %12#3, %5, %6, %7, %8, %9) : (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor, tensor<3x!tf_type.string>) -> (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor) + tf_device.return %20#0, %20#1, %20#2, %20#3, %20#4, %20#5, %20#6, %20#7, %20#8, %20#9 : tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor + }) : () -> (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor) + tf_device.return %19#0, %19#1, %19#2, %19#3, %19#4, %19#5, %19#6, %19#7, %19#8, %19#9 : tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor + }) : () -> (tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor, tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor, tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor, tensor, tensor<32x1024xf32>, tensor, tensor, tensor<1024xf32>, tensor<32x1024xf32>, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor) + %17 = "tf.Concat"(%cst_1, %16#5, %16#15, %16#25, %16#35) : (tensor, tensor<32x1024xf32>, tensor<32x1024xf32>, 
tensor<32x1024xf32>, tensor<32x1024xf32>) -> tensor<128x1024xf32> + %18 = "tf.Concat"(%cst_2, %16#1, %16#11, %16#21, %16#31) : (tensor, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) -> tensor<128x1024xf32> + tf_device.return %16#0, %16#9, %16#8, %16#7, %16#6, %17, %16#4, %16#3, %16#2, %18 : tensor, tensor, tensor<1024x1xf32>, tensor<1024xf32>, tensor<1024x1xf32>, tensor<128x1024xf32>, tensor<1024xf32>, tensor, tensor, tensor<128x1024xf32> + } + "tf.AssignVariableOp"(%arg19, %13#18) <{validate_shape = false}> : (tensor<*x!tf_type.resource>>, tensor<128x1024xf32>) -> () + "tf.AssignVariableOp"(%arg1, %13#16) <{validate_shape = false}> : (tensor<*x!tf_type.resource>>, tensor) -> () + "tf.AssignVariableOp"(%arg2, %13#14) <{validate_shape = false}> : (tensor<*x!tf_type.resource>>, tensor) -> () + "tf.AssignVariableOp"(%arg4, %13#12) <{validate_shape = false}> : (tensor<*x!tf_type.resource>>, tensor<1024xf32>) -> () + "tf.AssignVariableOp"(%arg3, %13#10) <{validate_shape = false}> : (tensor<*x!tf_type.resource>>, tensor<128x1024xf32>) -> () + "tf.AssignVariableOp"(%arg5, %13#8) <{validate_shape = false}> : (tensor<*x!tf_type.resource>>, tensor<1024x1xf32>) -> () + "tf.AssignVariableOp"(%arg20, %13#6) <{validate_shape = false}> : (tensor<*x!tf_type.resource>>, tensor<1024xf32>) -> () + "tf.AssignVariableOp"(%arg21, %13#4) <{validate_shape = false}> : (tensor<*x!tf_type.resource>>, tensor<1024x1xf32>) -> () + "tf.AssignVariableOp"(%arg6, %13#2) <{validate_shape = true}> {_has_manual_control_dependencies = true} : (tensor<*x!tf_type.resource>>, tensor) -> () + %14 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor + %15 = "tf.Identity"(%14) {device = ""} : (tensor) -> tensor + return %15 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc index c92fd85d3567b4..cd13e869e811dd 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc @@ -88,6 +88,8 @@ void AddTfDialectToExecutorPasses(OpPassManager &pm) { pm.addNestedPass(mlir::TFTPU::CreateTPUDevicePropagationPass()); pm.addNestedPass(mlir::TFTPU::CreateTPUColocateSplitsPass()); pm.addPass(mlir::createSymbolDCEPass()); + pm.addNestedPass( + mlir::tf_executor::CreateTFExecutorGraphPruningPass()); if (tensorflow::GetMlirCommonFlags() ->tf_mlir_enable_convert_control_to_data_outputs_pass) { bool composite_tpuexecute_side_effects = diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc index 0c64dd3dcbe1a3..897c800d9e4cd7 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc @@ -15,12 +15,16 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h" +#include + #include #include #include #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -30,7 +34,6 @@ limitations under the License. 
#include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/resource_loader.h" #include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/status.h" namespace tensorflow { namespace tf2xla { @@ -53,6 +56,16 @@ std::string TestDataPath() { "tensorflow/compiler/mlir/tf2xla/api/v2/testdata/"); } +size_t CountSubstring(absl::string_view str, absl::string_view substr) { + size_t count = 0; + size_t idx = str.find(substr); + while (idx != std::string::npos) { + count++; + idx = str.find(substr, idx + 1); + } + return count; +} + class TensorflowDialectToExecutorTest : public ::testing::Test { public: TensorflowDialectToExecutorTest() { @@ -100,6 +113,23 @@ TEST_F(TensorflowDialectToExecutorTest, ErrorsWhenCannotConvert) { EXPECT_EQ(compilation_status.Delta(kExportFailed), 1); } +TEST_F(TensorflowDialectToExecutorTest, PrunesDeadOps) { + CellReader compilation_status(kExportStreamzName); + + TF_ASSERT_OK(CreateMlirModule("func_with_dead_ops.mlir")); + + TF_EXPECT_OK(ExportFromTensorflowDialectToExecutor(*mlir_module_)); + + std::string module_dump; + llvm::raw_string_ostream raw_stream(module_dump); + mlir_module_->print(raw_stream); + + EXPECT_EQ(compilation_status.Delta(kExportSuccess), 1); + EXPECT_EQ(compilation_status.Delta(kExportFailed), 0); + EXPECT_EQ( + CountSubstring(module_dump, "tf_executor.island wraps \"tf.Concat\""), 2); +} + } // namespace } // namespace v2 } // namespace tf2xla diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index d0909c452a0325..d702cf308f8f80 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -35,8 +35,6 @@ namespace internal { using mlir::OpPassManager; using mlir::func::FuncOp; -// LINT.IfChange(replicated_bridge_passes) - // Adds replicated Bridge clustering pipeline passes to the given pass_manager. // Does not run them. void AddReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm, @@ -151,7 +149,8 @@ void AddReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm, pm.addPass(mlir::TFDevice::CreateClusterOutliningPass()); pm.addPass(mlir::TFTPU::CreateTPUResourceReadForWritePass()); pm.addPass(mlir::TFDevice::CreateMarkInputOutputAliasesPass()); - pm.addPass(mlir::TFTPU::CreateTPUShardingIdentificationPass()); + pm.addPass( + tensorflow::tf2xla::internal::CreateTPUShardingIdentificationPass()); pm.addNestedPass( mlir::TFTPU::CreateTPUResourceReadsWritesPartitioningPass()); pm.addPass(mlir::TFDevice::CreateAnnotateParameterReplicationPass()); @@ -163,12 +162,9 @@ void AddReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm, pm.addNestedPass( tensorflow::tf2xla::internal::CreateVerifyClusteringPass()); } -// LINT.ThenChange(:non_replicated_bridge_passes) void NoCanonicalization(OpPassManager& pm) {} -// LINT.IfChange(non_replicated_bridge_passes) - // Same as above but for non-replicated Bridge. void AddNonReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm) { // The following ops must be preserved regardless of reachability. 
Ideally, @@ -218,7 +214,6 @@ void AddNonReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm) { pm.addNestedPass( tensorflow::tf2xla::internal::CreateVerifyClusteringPass()); } -// LINT.ThenChange(:replicated_bridge_passes) }; // namespace internal }; // namespace tf2xla diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc index 7bf4c74e094af5..4adb8ebd160d57 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc @@ -66,7 +66,7 @@ LogicalResult HasAttr( // This is not expected to happen in practice if (!status.ok()) { LOG(ERROR) << "Failed to parse " << func_name << ": " - << tsl::NullTerminatedMessage(status); + << absl::StatusMessageAsCStr(status); return failure(); } if (predicate(*func_body->graph)) { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index 9641e092815b58..4ef78ef8d9b18d 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -30,6 +30,7 @@ cc_library( ":hoist_broadcast_read", ":mark_ops_for_outside_compilation", ":tpu_cluster_formation", + ":tpu_sharding_identification_pass", ":verify_clustering_pass", ":xla_broadcast", ":xla_cluster_formation", @@ -322,6 +323,47 @@ cc_library( ], ) +cc_library( + name = "tpu_sharding_identification_pass", + srcs = ["tpu_sharding_identification_pass.cc"], + textual_hdrs = [ + "clustering_passes.h.inc", + ], + deps = [ + ":clustering_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:string_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", + "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", + "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/client:sharding_builder", + ], +) + cc_library( name = "hoist_broadcast_read", srcs = ["hoist_broadcast_read.cc"], diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h index fb6e32ac377b79..85703c2306ad6b 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h +++ 
b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h
@@ -62,11 +62,17 @@ CreateHoistBroadcastReadPass();
 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
 CreateXlaBroadcastPass();

+// Creates a pass that identifies XLASharding ops in launch op for TPU
+// computation.
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateTPUShardingIdentificationPass();
+
 #define GEN_PASS_REGISTRATION
 #define GEN_PASS_DECL_MARKOPSFOROUTSIDECOMPILATIONPASS
 #define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS
 #define GEN_PASS_DECL_TPUEXTRACTHEADTAILOUTSIDECOMPILATIONPASS
 #define GEN_PASS_DECL_TPUEXTRACTOUTSIDECOMPILATIONPASS
+#define GEN_PASS_DECL_TPUSHARDINGIDENTIFICATIONPASS
 #define GEN_PASS_DECL_VERIFYCLUSTERINGPASS
 #define GEN_PASS_DECL_XLACLUSTERFORMATIONPASS
 #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc"
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td
index 2f617f7c154935..c1c34561ff0eb7 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td
@@ -390,3 +390,55 @@ def XlaBroadcastPass : Pass<"tf-xla-broadcast", "mlir::func::FuncOp"> {
   let constructor = "tensorflow::tf2xla::internal::CreateXlaBroadcastPass()";
   let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"];
 }
+
+def TPUShardingIdentificationPass : Pass<"tf-tpu-sharding-identification", "ModuleOp"> {
+  let summary = "Identifies and handles inputs/outputs of TPU computation that is "
+                "sharded across logical cores.";
+  let constructor = "tensorflow::tf2xla::internal::CreateTPUShardingIdentificationPass()";
+  let description = [{
+    Bubbles up sharding configuration from `cluster_func` regions into
+    the attributes of `cluster_func`. This is done by parsing the
+    `XlaSharding` / `TPUPartitionedOutput` / `TPUPartitionedInput` ops inside
+    `cluster_func`.
+ + For example, given the following `cluster_func` wrapping `func`: + + ```mlir + func @test(%arg0: tensor<*xi32>) { + "tf_device.cluster_func"(%arg0) { + func = @func, + step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32> + return + } + + func @func(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "\01\02\03", + sharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32> + %1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>) + return %1 : tensor<*xi32> + } + ``` + + Now, cluster_func receives the following `*_sharding_configuration` + attributes, and `func` receives the mhlo.sharding attribute: + + ```mlir + func @test(%arg0: tensor<*xi32>) { + %0 = "tf_device.cluster_func"(%arg0) { + func = @func, + input_sharding_configuration = ["\01\02\03"], + output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], + step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32> + return + } + func @func(%arg0: tensor<*xi32> {mhlo.sharding = "\01\02\03"}) -> + (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { + %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "\01\02\03", sharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32> + %1 = "tf.A"(%0) : (tensor<*xi32>) -> tensor<*xi32> + return %1 : tensor<*xi32> + } + ``` + }]; +} + + diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc index ad85310291c146..e0dc7bda1f9c86 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" @@ -78,7 +79,7 @@ bool HasOutsideCompilationAttribute(Operation* op) { // Finds op that created a given value. If the value is a BlockArgument, this // returns the owner of the Block. 
Operation* GetOpOfValue(Value value) { - if (auto block_arg = value.dyn_cast()) + if (auto block_arg = mlir::dyn_cast(value)) return block_arg.getOwner()->getParentOp(); return value.getDefiningOp(); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc index 6bc3468a2729e3..66340e57012e69 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc @@ -97,7 +97,6 @@ constexpr char kDeviceAttr[] = "device"; constexpr char kHostFunctionAttr[] = "host_func"; constexpr char kXlaMapOutsideCompilationAttr[] = "_xla_map_outside_compilation"; constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; -constexpr char kNoReplicationCluster[] = "__no_replication_cluster"; #define GEN_PASS_DEF_EXTRACTOUTSIDECOMPILATIONPASS #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" @@ -202,7 +201,6 @@ Operation* ApplyXlaHostTransferAttr(Operation* op, OpBuilder& builder) { Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc, ValueRange inputs, Value compilation_key, Value device_ordinal, - int default_device_ordinal, StringAttr device_type_attr, llvm::StringRef communication_key) { if (device_ordinal) @@ -218,8 +216,7 @@ Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc, loc, inputs, /*dynamic_key=*/compilation_key, builder.getStringAttr(communication_key), - /*device_ordinal=*/builder.getI64IntegerAttr(default_device_ordinal), - device_type_attr), + /*device_ordinal=*/builder.getI64IntegerAttr(0), device_type_attr), builder); } @@ -227,8 +224,7 @@ Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc, // present, a tf._XlaRecvAtHostV2 op is created instead. Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc, TypeRange output_types, Value compilation_key, - Value device_ordinal, int default_device_ordinal, - StringAttr device_type_attr, + Value device_ordinal, StringAttr device_type_attr, llvm::StringRef communication_key) { if (device_ordinal) return ApplyXlaHostTransferAttr( @@ -241,8 +237,7 @@ Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc, builder.create( loc, output_types, /*dynamic_key=*/compilation_key, builder.getStringAttr(communication_key), - /*device_ordinal=*/builder.getI64IntegerAttr(default_device_ordinal), - device_type_attr), + /*device_ordinal=*/builder.getI64IntegerAttr(0), device_type_attr), builder); } @@ -386,7 +381,7 @@ llvm::SmallSetVector GetStaticExternalOperands( } continue; } - auto block_arg = v.cast(); + auto block_arg = mlir::cast(v); if (block_arg.getParentRegion() == op->getParentRegion()) external_values.insert(v); } @@ -475,7 +470,7 @@ void GetExternalOutputs(const llvm::SmallSetVector& cluster_ops, LogicalResult GetShardShapedType(Operation* context_op, int num_cores_per_replica, Type full_type, Type& shard_type) { - RankedTensorType ranked_type = full_type.dyn_cast(); + RankedTensorType ranked_type = mlir::dyn_cast(full_type); if (!ranked_type) return context_op->emitOpError() << "A map_outside_compilation op's input and output types must be " @@ -587,7 +582,8 @@ LogicalResult CreateHostComputeMap( // Convert MANUAL sharded outputs to split sharded outputs. 
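A pattern repeated across the hunks above and below: MLIR's member-style casts (`value.dyn_cast<T>()`, `type.isa<T>()`) are rewritten as the free functions `mlir::dyn_cast` / `mlir::isa` / `mlir::cast` pulled in via `mlir/Support/LLVM.h`, which upstream MLIR now prefers. A minimal sketch of the new style, mirroring `GetOpOfValue` above (the name `OwnerOf` is illustrative):

```c++
#include "mlir/IR/Block.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

// Free-function cast style: works uniformly on Value, Type, and Attribute.
mlir::Operation* OwnerOf(mlir::Value v) {
  // Block arguments have no defining op; return the op owning their block.
  if (auto arg = mlir::dyn_cast<mlir::BlockArgument>(v))
    return arg.getOwner()->getParentOp();
  return v.getDefiningOp();
}
```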
for (auto [full_type, out] : llvm::zip(full_output_types, host_compute.getResults())) { - RankedTensorType full_type_ranked = full_type.dyn_cast(); + RankedTensorType full_type_ranked = + mlir::dyn_cast(full_type); if (!full_type_ranked) return original_op->emitOpError() << "map_outside_compilation must have ranked outputs"; @@ -775,9 +771,9 @@ Operation* CreateHostOps(ArrayRef clustered_ops, ArrayRef external_operands, ArrayRef external_outputs, Operation* host_insertion_point, Value compilation_key, - Value device_ordinal, int default_device_ordinal, - StringAttr device_type_attr, OpBuilder& builder, - Operation& op, std::string args_communication_key, + Value device_ordinal, StringAttr device_type_attr, + OpBuilder& builder, Operation& op, + std::string args_communication_key, std::string retvals_communication_key, SmallVector& host_ops) { builder.setInsertionPoint(host_insertion_point); @@ -787,7 +783,7 @@ Operation* CreateHostOps(ArrayRef clustered_ops, Operation* recv_at_host = CreateRecvAtHostOp( builder, op.getLoc(), host_operand_types, compilation_key, device_ordinal, - default_device_ordinal, device_type_attr, args_communication_key); + device_type_attr, args_communication_key); if (!external_operands.empty()) host_ops.push_back(recv_at_host); Operation* after_op = recv_at_host; @@ -801,7 +797,7 @@ Operation* CreateHostOps(ArrayRef clustered_ops, if (!external_outputs.empty()) { Operation* send_from_host = CreateSendFromHostOp( builder, op.getLoc(), external_outputs, compilation_key, device_ordinal, - default_device_ordinal, device_type_attr, retvals_communication_key); + device_type_attr, retvals_communication_key); host_ops.push_back(send_from_host); } @@ -855,8 +851,8 @@ LogicalResult MoveToHostSingleCluster( llvm::SmallVector& core_to_mapping, ArrayRef core_to_host_insertion_point, ArrayRef core_to_compilation_key, - ArrayRef core_to_device_ordinal, int default_device_ordinal, - StringAttr device_type_attr, bool is_map_oc, int num_cores_per_replica, + ArrayRef core_to_device_ordinal, StringAttr device_type_attr, + bool is_map_oc, int num_cores_per_replica, std::string& common_split_sharding, int& communication_key_index) { OpBuilder builder(core_to_host_insertion_point[0]); Operation& op = *clustered_ops.back(); @@ -891,8 +887,8 @@ LogicalResult MoveToHostSingleCluster( clustered_ops, external_operands, external_outputs, core_to_host_insertion_point[0], core_to_compilation_key[0], core_to_device_ordinal.empty() ? nullptr : core_to_device_ordinal[0], - default_device_ordinal, device_type_attr, builder, op, - args_communication_key, retvals_communication_key, host0_ops); + device_type_attr, builder, op, std::move(args_communication_key), + std::move(retvals_communication_key), host0_ops); if (external_operands.empty()) { recv_at_host->erase(); @@ -960,9 +956,8 @@ LogicalResult MoveToHostMultiCluster( mlir::tf_device::ClusterOp device_cluster, Block* src, ArrayRef core_to_host_insertion_point, ArrayRef core_to_compilation_key, - ArrayRef core_to_device_ordinal, int default_device_ordinal, - bool control_above, std::optional& is_map_oc, - int& communication_key_index, + ArrayRef core_to_device_ordinal, bool control_above, + std::optional& is_map_oc, int& communication_key_index, llvm::SmallVector* return_value_from_host = nullptr) { const int num_cores_per_replica = core_to_host_insertion_point.size(); // common_split_sharding is set upon the first use of map_outside_compilation. 
@@ -1005,8 +1000,8 @@ LogicalResult MoveToHostMultiCluster( clustered_ops.getArrayRef(), external_operands.getArrayRef(), external_outputs.getArrayRef(), core_to_mapping, core_to_host_insertion_point, core_to_compilation_key, - core_to_device_ordinal, default_device_ordinal, device_type_attr, - *is_map_oc, num_cores_per_replica, common_split_sharding, + core_to_device_ordinal, device_type_attr, *is_map_oc, + num_cores_per_replica, common_split_sharding, communication_key_index))) return mlir::failure(); clustered_ops.clear(); @@ -1033,8 +1028,8 @@ LogicalResult MoveToHostMultiCluster( clustered_ops.getArrayRef(), external_operands.getArrayRef(), external_outputs.getArrayRef(), core_to_mapping, core_to_host_insertion_point, core_to_compilation_key, - core_to_device_ordinal, default_device_ordinal, device_type_attr, - *is_map_oc, num_cores_per_replica, common_split_sharding, + core_to_device_ordinal, device_type_attr, *is_map_oc, + num_cores_per_replica, common_split_sharding, communication_key_index))) return mlir::failure(); clustered_ops.clear(); @@ -1068,7 +1063,6 @@ void GetReturnValueFromDevice( LogicalResult DecomposeControlFlow(mlir::tf_device::ClusterOp device_cluster, ArrayRef core_to_compilation_key, ArrayRef core_to_device_ordinal, - int default_device_ordinal, int& communication_key_index, std::optional& is_map_oc) { auto result = device_cluster.GetBody().walk([&](Operation* op) { @@ -1080,15 +1074,13 @@ LogicalResult DecomposeControlFlow(mlir::tf_device::ClusterOp device_cluster, device_cluster, &if_op.getThenBranch().front(), {host_if.getThenBranch().front().getTerminator()}, core_to_compilation_key, core_to_device_ordinal, - default_device_ordinal, /*control_above=*/true, is_map_oc, - communication_key_index))) + /*control_above=*/true, is_map_oc, communication_key_index))) return WalkResult::interrupt(); if (failed(MoveToHostMultiCluster( device_cluster, &if_op.getElseBranch().front(), {host_if.getElseBranch().front().getTerminator()}, core_to_compilation_key, core_to_device_ordinal, - default_device_ordinal, /*control_above=*/true, is_map_oc, - communication_key_index))) + /*control_above=*/true, is_map_oc, communication_key_index))) return WalkResult::interrupt(); // Mark op as stateful due to side-effecting communication ops. 
if_op->setAttr("is_stateless", builder.getBoolAttr(false)); @@ -1118,7 +1110,7 @@ LogicalResult DecomposeControlFlow(mlir::tf_device::ClusterOp device_cluster, builder.setInsertionPointToEnd(&cond.front()); auto recv_condition_at_host = CreateRecvAtHostOp( builder, while_op.getLoc(), TypeRange{condition.getType()}, - core_to_compilation_key[0], device_ordinal0, default_device_ordinal, + core_to_compilation_key[0], device_ordinal0, device_cluster->getAttrOfType( mlir::TF::kCompileDeviceTypeAttr), condition_send_recv_key); @@ -1128,15 +1120,14 @@ LogicalResult DecomposeControlFlow(mlir::tf_device::ClusterOp device_cluster, if (failed(MoveToHostMultiCluster( device_cluster, &while_op.getCond().front(), {recv_condition_at_host}, core_to_compilation_key, - core_to_device_ordinal, default_device_ordinal, + core_to_device_ordinal, /*control_above=*/true, is_map_oc, communication_key_index))) return WalkResult::interrupt(); if (failed(MoveToHostMultiCluster( device_cluster, &while_op.getBody().front(), {host_while.getBody().front().getTerminator()}, core_to_compilation_key, core_to_device_ordinal, - default_device_ordinal, /*control_above=*/true, is_map_oc, - communication_key_index))) + /*control_above=*/true, is_map_oc, communication_key_index))) return WalkResult::interrupt(); // Mark op as stateful due to side-effecting communication ops. while_op->setAttr("is_stateless", builder.getBoolAttr(false)); @@ -1160,45 +1151,6 @@ void RemoveOutsideCompilation(mlir::tf_device::LaunchOp host_launch_op) { }); } -// This method extracts default ordinal or default device core associated with a -// host. -// If the cluster has replication attribute and it is not empty, then it means -// it is replicated case and then NO ordinal info is extracted but -// if it is non replicated cluster and there is a device attr with some -// non-empty device, then that device's ordinal (0 out of TPU:0 and -// 1 out of TPU:1) is extracted and the default ordinal is set to this value. -LogicalResult GetDefaultDeviceOrdinal(mlir::tf_device::ClusterOp device_cluster, - int& default_ordinal) { - bool has_replication = - device_cluster->hasAttr(mlir::TF::kReplicationInfoAttr); - - std::string replication_info; - if (has_replication) { - replication_info = - device_cluster - ->getAttrOfType(mlir::TF::kReplicationInfoAttr) - .str(); - } - if (replication_info == kNoReplicationCluster || replication_info.empty()) { - has_replication = false; - } - if (!has_replication && - device_cluster->hasAttrOfType(kDeviceAttr) && - !device_cluster->getAttrOfType(kDeviceAttr).str().empty()) { - int64_t ordinal = 0; - mlir::LogicalResult result = tensorflow::GetDeviceOrdinalFromDeviceString( - mlir::UnknownLoc::get(device_cluster.getContext()), - device_cluster->getAttrOfType(kDeviceAttr).str(), &ordinal); - if (succeeded(result)) { - default_ordinal = ordinal; - } else { - return device_cluster.emitError() - << " could not find ordinal for the given device"; - } - } - return mlir::success(); -} - // The results of parallel executes is the combination of return values from // both host and device. llvm::SmallVector GetParallelExecuteResultsTypes( @@ -1485,19 +1437,15 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( } builder.setInsertionPoint(tmp_parallel_execute_op); - int default_device_ordinal = 0; - if (failed(GetDefaultDeviceOrdinal(device_cluster, default_device_ordinal))) { - return mlir::failure(); - } // communication_key_index is part of the message identifier and is // incremented for each _XlaHostComputeMlir. 
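For context on the ordinal plumbing deleted above: the removed `GetDefaultDeviceOrdinal` pulled the trailing index out of a TPU device string (0 from `TPU:0`, 1 from `TPU:1`), and the send/recv builders now simply hardcode ordinal 0. A rough sketch of that extraction using the `DeviceNameUtils` API that appears elsewhere in this patch; `GetTpuOrdinal` is a hypothetical name and error reporting is elided:

```c++
#include <cstdint>
#include <string>

#include "tensorflow/core/util/device_name_utils.h"

// Sketch: "/job:worker/replica:0/task:0/device:TPU:1" -> ordinal 1.
bool GetTpuOrdinal(const std::string& device_name, int64_t& ordinal) {
  tensorflow::DeviceNameUtils::ParsedName parsed;
  if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_name, &parsed))
    return false;
  if (parsed.type != "TPU" || !parsed.has_id) return false;
  ordinal = parsed.id;
  return true;
}
```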
int communication_key_index = 0; // Decompose control flow into device and host control flow when outside // compilation is included. - if (failed(DecomposeControlFlow( - device_cluster, core_to_compilation_key, core_to_device_ordinal, - default_device_ordinal, communication_key_index, is_map_oc))) + if (failed(DecomposeControlFlow(device_cluster, core_to_compilation_key, + core_to_device_ordinal, + communication_key_index, is_map_oc))) return mlir::failure(); // Move all outside compiled ops including control flow to tmp host launch. @@ -1505,7 +1453,7 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( if (failed(MoveToHostMultiCluster( device_cluster, &device_cluster.GetBody(), core_to_host_insertion_point, core_to_compilation_key, - core_to_device_ordinal, default_device_ordinal, + core_to_device_ordinal, /*control_above=*/false, is_map_oc, communication_key_index, &returns_from_host))) return mlir::failure(); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc index 732bae8c67b018..f16df445439084 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc @@ -72,7 +72,7 @@ Operation* GetAncestorBelow(Operation* descendant, Operation* ancestor) { // `is_cpu_read` is set to `true` iff `read` is on a resource with device type // CPU. LogicalResult IsCpuRead(FuncOp func, ReadVariableOp read, bool& is_cpu_read) { - if (auto arg = read->getOperand(0).dyn_cast()) { + if (auto arg = mlir::dyn_cast(read->getOperand(0))) { if (arg.getOwner() != &(func.front())) { is_cpu_read = false; return success(); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc index bc0e25f505e11b..d6c92101bf608a 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc @@ -30,9 +30,7 @@ namespace internal { namespace { -using llvm::DenseSet; using mlir::Operation; -using mlir::TypeID; using mlir::WalkResult; #define GEN_PASS_DEF_INPUTLOWERINGMETRICSPASS diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc index dde1fd4514d719..7308669b6359cb 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc @@ -38,6 +38,7 @@ limitations under the License. 
#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Rewrite/PatternApplicator.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -161,6 +162,8 @@ void AddSupportedFunctionalOps(MLIRContext* context, OperationName(mlir::TF::WhileRegionOp::getOperationName(), context)); supported_ops->insert( OperationName(mlir::TF::XlaCallModuleOp::getOperationName(), context)); + supported_ops->insert( + OperationName(mlir::TF::XlaHostComputeOp::getOperationName(), context)); supported_ops->insert( OperationName(mlir::TF::XlaReduceOp::getOperationName(), context)); supported_ops->insert( @@ -236,13 +239,13 @@ void AddRewrittenCompositeOps(MLIRContext* context, } bool IsStringType(Type type) { - if (type.isa()) return true; + if (mlir::isa(type)) return true; - auto sub_type = type.dyn_cast(); + auto sub_type = mlir::dyn_cast(type); if (!sub_type) return false; bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) { - return type.getElementType().isa(); + return mlir::isa(type.getElementType()); }); return has_string; } @@ -288,7 +291,8 @@ bool IsSupportedOp(Operation& op, } bool IsVariant(Value value) { - return getElementTypeOrSelf(value.getType()).isa(); + return mlir::isa( + getElementTypeOrSelf(value.getType())); } bool HasOutsideCompiledAncestor(Operation* op) { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc index b600c865661d58..e76e2e3bc86b14 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc @@ -51,6 +51,7 @@ limitations under the License. #include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" @@ -90,7 +91,6 @@ constexpr llvm::StringRef kNumCoresPerReplicaAttr = "num_cores_per_replica"; constexpr llvm::StringRef kNumReplicasAttr = "num_replicas"; constexpr llvm::StringRef kMirroredVariableIndicesAttr = "_mirrored_variable_indices"; -constexpr llvm::StringRef kNoReplicationCluster = "__no_replication_cluster"; constexpr llvm::StringRef kBadReplicateInfoAttrMsg = "requires '_replication_info' string attribute"; @@ -142,7 +142,7 @@ LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) { return metadata_op.emitError() << kBadReplicateInfoAttrMsg; auto replication_info_attr_str = - replication_info_attr.dyn_cast(); + mlir::dyn_cast(replication_info_attr); if (!replication_info_attr_str || replication_info_attr_str.getValue().empty()) return metadata_op.emitError() << kBadReplicateInfoAttrMsg; @@ -171,39 +171,50 @@ struct OpDevice { std::string device; }; -// Collects and clusters ops either based on `_replication_info` attribute -// (replicated case) or using one single cluster (non-replicated case). 
Also -// sets `device_type` if there is any cluster (note that the device type must be -// unique, otherwise we emit an error). -// Returns an error in case of invalid compilation or replication attribute(s). -LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters, - std::string& device_type, - std::string& device) { - bool has_replicated_compiled_op = false; - bool has_non_replicated_compiled_op = false; - bool has_local_device_name_collisions = false; +LogicalResult HasValidDeviceTypeAttribute(Block* block) { // Use ordered set here to make error message below deterministic. std::set device_types; - absl::flat_hash_map devices; + for (Operation& op : *block) { + // Collect device types which currently must be consistent per block + // (checked later). + if (auto device_type_attr = + op.getAttrOfType(mlir::TF::kCompileDeviceTypeAttr)) { + // tf.StatefulPartitionedCall ops with and without + // _tpu_replicate attributes may exist in the same graph. Ops without + // the attribute but with _XlaMustCompile=true would have + // _xla_compile_device_type="" after + // CanonicalizeCompileAndReplicateAttributesPass. Skip empty value here. + if (!device_type_attr.getValue().empty()) { + device_types.insert(device_type_attr); + } + } + } + + if (device_types.size() > 1) { + return block->getParentOp()->emitError() + << "found different '" << mlir::TF::kCompileDeviceTypeAttr + << "' attribute values (" << llvm::join(device_types, ",") + << ") in same block which is not supported"; + } + return success(); +} + +// Collects and clusters ops based on `_replication_info` attribute. Returns +// an error in case of invalid compilation or replication attribute(s). +LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) { + LogicalResult result = HasValidDeviceTypeAttribute(block); + if (failed(result)) return result; + for (Operation& op : *block) { LogicalResult result = mlir::TF::HasValidCompilationAndReplicationAttributes(op); if (failed(result)) return result; - // Collect device types which currently must be consistent per block - // (checked later). + // Skip ops with non-TPU device type, they are handled elsewhere. auto device_type_attr = op.getAttrOfType(mlir::TF::kCompileDeviceTypeAttr); if (device_type_attr) { - // Some graphs in TPU bridge may have both tf.StatefulPartitionedCall - // ops with and without _tpu_replicate attributes. As a result, the ops - // without such attribute would have _xla_compile_device_type="" after - // CanonicalizeCompileAndReplicateAttributesPass, if they also had - // _XlaMustCompile = true before the pass. We should filter out such - // unspecified device type here. if (device_type_attr.getValue().empty()) continue; - device_types.insert(device_type_attr); - // Stop here for ops with non-TPU devices, they are handled elsewhere. if (device_type_attr.getValue() != mlir::TF::kTpuDevice) continue; } @@ -213,105 +224,10 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters, // `HasValidCompilationAndReplicationAttributes` above, assert here for // documentation and to avoid breakage when that function is changed. assert(op.hasAttr(mlir::TF::kCompileDeviceTypeAttr)); - has_replicated_compiled_op = true; auto attr = op.getAttrOfType(mlir::TF::kReplicationInfoAttr); auto it = clusters->try_emplace(attr.getValue()); it.first->getSecond().insert(&op); - } else if (op.hasAttr(mlir::TF::kCompileDeviceTypeAttr)) { - // For non-replicated case, assume one cluster per block (in line with - // Framework behavior). 
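The refactor above moves the device-type consistency rule into its own helper. The invariant can be sketched on its own, assuming `mlir::TF::kCompileDeviceTypeAttr` denotes the `_xla_compile_device_type` string used in this file; `CheckSingleDeviceType` is a hypothetical reduction of `HasValidDeviceTypeAttribute` without its diagnostics:

```c++
#include <set>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

// All non-empty compile-device-type attributes in a block must agree.
mlir::LogicalResult CheckSingleDeviceType(mlir::Block* block) {
  std::set<llvm::StringRef> device_types;
  for (mlir::Operation& op : *block)
    if (auto attr = op.getAttrOfType<mlir::StringAttr>(
            "_xla_compile_device_type"))
      if (!attr.getValue().empty()) device_types.insert(attr.getValue());
  return mlir::success(device_types.size() <= 1);
}
```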
- has_non_replicated_compiled_op = true; - auto it = clusters->try_emplace(kNoReplicationCluster); - it.first->getSecond().insert(&op); } - auto device_attr = op.getAttrOfType(kDeviceAttr); - std::string device_local_name; - bool is_tpu_device = false; - if (device_attr && !device_attr.str().empty()) { - tensorflow::DeviceNameUtils::ParsedName parsed; - if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_attr.str(), - &parsed)) { - op.emitWarning() << "Invalid device name " << device_attr.str(); - return mlir::failure(); - } - - device_local_name = - tensorflow::DeviceNameUtils::LocalName(parsed.type, parsed.id); - is_tpu_device = parsed.type == "TPU"; - } - - // Ignore non-TPU devices when clustering. - if (!is_tpu_device) { - continue; - } - - if (!has_replicated_compiled_op && !device_local_name.empty()) { - // It is possible that a device may be same Local Name but - // different fullname. Devices with same Local name are identical - // so they should only be added once in 'devices'. - // and we need the fullname which is longer since longer name has more - // information such as task, replica, job etc. An example fullname is - // "/job:foo_bar/replica:1/task:2/device:GPU:3" - if (devices.count(device_local_name)) { - std::string device1 = devices[device_local_name].device; - std::string device2 = device_attr.str(); - // Is either of the two devices just a substring of the other? If - // not, we treat them as different devices, and we have a collision. - if (device1.find(device2) == std::string::npos && - device2.find(device1) == std::string::npos) { - Operation* previous_op = devices[device_local_name].op; - has_local_device_name_collisions = true; - - LOG_FIRST_N(WARNING, 1) - << "Found two devices with same local name " << device_local_name - << " but conflicting fullname: " << device1 << " and " << device2 - << "."; - LOG_FIRST_N(WARNING, 1) - << "Previous assignment came from op: " - << tensorflow::OpAsString(*previous_op) - << ". Current op is: " << tensorflow::OpAsString(op); - } - // Always keep the longer name. - if (devices[device_local_name].device.size() < - device_attr.str().size()) { - devices[device_local_name] = {&op, device_attr.str()}; - } - } else { - devices.insert({device_local_name, {&op, device_attr.str()}}); - } - } - } - // Do some checks for unsupported cases. - if (has_replicated_compiled_op && has_non_replicated_compiled_op) { - return block->getParentOp()->emitError() - << "found mixed replicated and non-replicated compiled ops in same " - "block which is not supported"; - } - if (device_types.size() > 1) { - return block->getParentOp()->emitError() - << "found different '" << mlir::TF::kCompileDeviceTypeAttr - << "' attribute values (" << llvm::join(device_types, ",") - << ") in same block which is not supported"; - } - if (!has_replicated_compiled_op) { - if (devices.size() > 1) { - LOG(WARNING) << "found different devices for no replication: "; - for (const auto& device_names : devices) { - LOG(WARNING) << device_names.first << ", " - << device_names.second.device; - } - } else if (has_local_device_name_collisions) { - LOG(WARNING) << "Not assigning device because of conflicting fullnames."; - } else if (devices.size() == 1 && - absl::StrContains(devices.begin()->second.device, "TPU:")) { - device = devices.begin()->second.device; - } - } - if (!clusters->empty()) { - // Note that for size < 1 we shouldn't have any cluster while for size > 1 - // we should have returned with an error above. 
- assert(device_types.size() == 1); - device_type = device_types.begin()->str(); } return success(); } @@ -637,7 +553,7 @@ Operation* BuildPartitionedOutputs( builder.create(result_op->getLoc(), results); // Then erase all the identity and partitioned output ops. - for (auto [_, ops] : partitioned_outputs) { + for (const auto& [_, ops] : partitioned_outputs) { for (mlir::TF::TPUPartitionedOutputV2Op op : ops) { op->erase(); } @@ -885,39 +801,14 @@ LogicalResult ReplicateCluster(mlir::tf_device::ClusterOp cluster, return success(); } -void SetNoReplicationClusterAttrs(mlir::tf_device::ClusterOp cluster, - llvm::StringRef device_type, - llvm::StringRef device) { - OpBuilder builder(cluster); - cluster->setAttr(mlir::TF::kReplicationInfoAttr, - builder.getStringAttr(kNoReplicationCluster)); - cluster->setAttr(mlir::TF::kCompileDeviceTypeAttr, - builder.getStringAttr(device_type)); - - if (!device.empty()) { - cluster->setAttr(kDeviceAttr, builder.getStringAttr(device)); - } - // TODO(b/229992058) Propagate `allow_soft_placement` (and other attributes?) - // instead of hard-coding. - cluster->setAttr("allow_soft_placement", builder.getBoolAttr(true)); - cluster->setAttr("topology", builder.getStringAttr("")); - cluster->setAttr("num_cores_per_replica", - builder.getIntegerAttr(builder.getI32Type(), 1)); - cluster->setAttr("device_assignment", builder.getArrayAttr({})); - cluster->setAttr("use_spmd_for_xla_partitioning", builder.getBoolAttr(false)); - cluster->setAttr("step_marker_location", builder.getStringAttr("")); -} - -// Forms compilation clusters in `block`. If the block contains a -// `TPUReplicateMetadata` op, then we form clusters according to -// `_replication_info` values (ops with same value go to same cluster). -// Otherwise, in the non-replicated case, we build one compilation cluster per +// Forms clusters with ops of the same `_replication_info` attribute under a // block. // -// We do this in following steps: -// 1. Find `TPUReplicateMetadata` op in `block` (might not exist). -// 2. Collect and group cluster ops (either based on `_replication_info` -// attributes or forming one single cluster). +// For a given block, clusters are formed via grouping ops by +// `_replication_info` attributes. For every cluster formed: +// 1. Find associated TPUReplicateMetadata attributes with the same +// `_replication_info` attribute. +// 2. Find users not in cluster that are interleaved between cluster ops. // 3. Find external uses of cluster ops. // 4. Create `tf_device.cluster` with results consisting of the external uses // of cluster ops determined at 3. @@ -948,24 +839,21 @@ LogicalResult FormClustersInBlock( return mlir::failure(); } } + return success(); } ClusterMap clusters; - std::string device_type; - std::string device; - result = CollectAndGroupClusterOps(block, &clusters, device_type, device); + result = CollectAndGroupClusterOps(block, &clusters); if (failed(result)) return result; for (const auto& cluster_metadata_and_ops : clusters) { const auto& cluster_ops = cluster_metadata_and_ops.getSecond(); - bool has_replication = - cluster_metadata_and_ops.getFirst() != kNoReplicationCluster; auto cluster_metadata = metadata_map.find(cluster_metadata_and_ops.getFirst()); - // No TPUReplicateMetadata for a `_replication_info` attribute. - if (has_replication && cluster_metadata == metadata_map.end()) { + // No TPUReplicateMetadata for a `_replication_info` attribute. 
+ if (cluster_metadata == metadata_map.end()) { block->getParentOp()->emitWarning() << "TPUReplicateMetadata for associated '" << mlir::TF::kReplicationInfoAttr << "' attribute '" @@ -984,27 +873,19 @@ LogicalResult FormClustersInBlock( mlir::tf_device::ClusterOp cluster = CreateClusterOp( block, cluster_ops, results, cluster_successor_ops.getArrayRef()); - if (!has_replication) { - SetNoReplicationClusterAttrs(cluster, device_type, device); - continue; - } - // Determine `num_replicas`. - auto num_replicas_attr = - cluster_metadata->getSecond().get(kNumReplicasAttr); - if (!num_replicas_attr || !num_replicas_attr.isa()) + auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr); + if (!num_replicas || !num_replicas.isa()) return cluster.emitError() << "requires '" << kNumReplicasAttr << "' int attribute"; - int num_replicas = num_replicas_attr.cast().getInt(); - // Determine `num_cores_per_replica`. int num_cores_per_replica = 1; - auto num_cores_per_replica_attr = - cluster_metadata->getSecond() - .get(kNumCoresPerReplicaAttr) - .dyn_cast_or_null(); + auto num_cores_per_replica_attr = mlir::dyn_cast_or_null( + cluster_metadata->getSecond().get(kNumCoresPerReplicaAttr)); if (num_cores_per_replica_attr) num_cores_per_replica = num_cores_per_replica_attr.getInt(); - if (failed(ReplicateCluster(cluster, num_replicas, num_cores_per_replica))) + if (failed(ReplicateCluster(cluster, + num_replicas.cast().getInt(), + num_cores_per_replica))) return mlir::failure(); // Copy TPUReplicateMetadata attributes to `tf_device.cluster`. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc similarity index 88% rename from tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc rename to tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc index fb2588f50631e8..ba35b03e8d6be7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,14 +53,30 @@ limitations under the License. 
#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" -namespace mlir { -namespace TFTPU { +namespace tensorflow { +namespace tf2xla { +namespace internal { namespace { using OpShardingVariant = std::variant; using OpShardingVector = llvm::SmallVector; using OptionalOpShardingVector = llvm::SmallVector, 8>; +using llvm::StringRef; +using mlir::Block; +using mlir::BlockArgument; +using mlir::BoolAttr; +using mlir::Builder; +using mlir::IntegerAttr; +using mlir::LogicalResult; +using mlir::ModuleOp; +using mlir::Operation; +using mlir::OpOperand; +using mlir::OpResult; +using mlir::RankedTensorType; +using mlir::StringAttr; +using mlir::Value; +using mlir::WalkResult; constexpr char kReplicateSharding[] = ""; constexpr char kShardingAttr[] = "mhlo.sharding"; @@ -69,7 +85,7 @@ constexpr char kAliasingAttr[] = "tf.aliasing_output"; constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica"; #define GEN_PASS_DEF_TPUSHARDINGIDENTIFICATIONPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" struct TPUShardingIdentificationPass : public impl::TPUShardingIdentificationPassBase< @@ -93,11 +109,11 @@ mlir::Operation* NullUnlessSharded(PartitionedOp op) { // a `tf_device.cluster_func`. mlir::Operation* GetXlaShardingFromOperand(Value value) { Value value_to_visit = value; - if (auto read_var = value_to_visit.getDefiningOp()) + if (auto read_var = value_to_visit.getDefiningOp()) value_to_visit = read_var.getResource(); if (auto partitioned_input = - value_to_visit.getDefiningOp()) { + value_to_visit.getDefiningOp()) { return NullUnlessSharded(partitioned_input); } @@ -107,10 +123,10 @@ mlir::Operation* GetXlaShardingFromOperand(Value value) { // Returns the op sharding attribute from a partitioned operator. std::optional GetXlaShardingFromOperator(mlir::Operation* op) { if (auto partitioned_output = - llvm::dyn_cast(op)) { + llvm::dyn_cast(op)) { return partitioned_output.get_XlaSharding(); } else if (auto partitioned_input = - llvm::dyn_cast(op)) { + llvm::dyn_cast(op)) { return partitioned_input.get_XlaSharding(); } else { return std::nullopt; @@ -174,9 +190,9 @@ LogicalResult VerifySharding(mlir::Type type, // Some test cases use \01\02\03 as sharding, to test propagation. Treat // a non-proto sharding as valid, and don't verify further. We also only // verify shardings that actually break a tensor apart. - return success(); + return mlir::success(); } - if (RankedTensorType ranked_type = type.dyn_cast()) { + if (RankedTensorType ranked_type = mlir::dyn_cast(type)) { const int64_t tensor_rank = ranked_type.getRank(); int tile_assignment_rank = sharding->tile_assignment_dimensions_size(); @@ -194,10 +210,10 @@ LogicalResult VerifySharding(mlir::Type type, << " extra dimension(s) by: " << sharding->DebugString(); } - return failure(); + return mlir::failure(); } } - return success(); + return mlir::success(); } // Verify sharding for all arguments and return values. 
@@ -209,7 +225,7 @@ LogicalResult VerifyShardings(mlir::func::FuncOp func, llvm::zip(sharding_for_args, function_block.getArguments())) { const auto& sharding = std::get<0>(sharding_and_arg); BlockArgument arg = std::get<1>(sharding_and_arg); - if (failed(VerifySharding(arg.getType(), sharding))) return failure(); + if (failed(VerifySharding(arg.getType(), sharding))) return mlir::failure(); } Operation* terminator = function_block.getTerminator(); for (auto sharding_and_retval : @@ -217,9 +233,9 @@ LogicalResult VerifyShardings(mlir::func::FuncOp func, const auto& sharding = std::get<0>(sharding_and_retval); OpOperand& retval = std::get<1>(sharding_and_retval); if (failed(VerifySharding(retval.get().getType(), sharding))) - return failure(); + return mlir::failure(); } - return success(); + return mlir::success(); } // Assign the logical device if an op has an attribute `TPU_REPLICATED_CORE:n`, @@ -262,7 +278,7 @@ std::optional GetXlaShardingFromArg( for (auto& use : value_to_visit.getUses()) { Operation* owner = use.getOwner(); - if (auto sharding = llvm::dyn_cast(owner)) + if (auto sharding = llvm::dyn_cast(owner)) return sharding.get_XlaSharding(); if (auto logical_device = AssignLogicalDeviceFromTPUReplicatedCoreAttr( @@ -270,7 +286,7 @@ std::optional GetXlaShardingFromArg( return logical_device; } - if (auto while_op = llvm::dyn_cast(owner)) { + if (auto while_op = llvm::dyn_cast(owner)) { const int operand_number = use.getOperandNumber(); next_values_to_visit.push_back( while_op.getCond().front().getArgument(operand_number)); @@ -279,14 +295,15 @@ std::optional GetXlaShardingFromArg( continue; } - if (llvm::isa(owner)) { + if (llvm::isa(owner)) { next_values_to_visit.push_back(use.getOwner()->getResult(0)); continue; } - if (auto call_op = llvm::dyn_cast(owner)) { - func::FuncOp func = - llvm::dyn_cast(call_op.resolveCallable()); + if (auto call_op = llvm::dyn_cast(owner)) { + mlir::func::FuncOp func = + llvm::dyn_cast(call_op.resolveCallable()); if (!func) continue; next_values_to_visit.push_back( func.getArgument(use.getOperandNumber())); @@ -307,8 +324,8 @@ std::optional GetXlaShardingFromArg( // XlaSharding op. void IdentifyXlaShardingForComputationInputs( const llvm::SmallVector& logical_device_vec, - bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, - func::FuncOp func, Builder* builder, + bool infer_from_computation, mlir::tf_device::ClusterFuncOp cluster_func, + mlir::func::FuncOp func, Builder* builder, OptionalOpShardingVector& sharding_for_args) { // Look up function definition from module. 
Block& function_block = func.front(); @@ -360,20 +377,20 @@ mlir::Operation* GetXlaShardingFromResult(Value value) { Operation* user = *value.getUsers().begin(); if (auto partitioned_output = - llvm::dyn_cast(user)) + llvm::dyn_cast(user)) return NullUnlessSharded(partitioned_output); - if (auto assign_var = llvm::dyn_cast(user)) + if (auto assign_var = llvm::dyn_cast(user)) if (auto partitioned_input = assign_var.getResource() - .getDefiningOp()) + .getDefiningOp()) return NullUnlessSharded(partitioned_input); return nullptr; } absl::Status DetermineShardingFromAlias( - func::FuncOp func, OptionalOpShardingVector& input_shardings, + mlir::func::FuncOp func, OptionalOpShardingVector& input_shardings, OptionalOpShardingVector& output_shardings) { for (int arg_idx = 0; arg_idx < func.getNumArguments(); ++arg_idx) { if (auto v = @@ -427,7 +444,7 @@ std::optional GetXlaShardingFromRetval( continue; } - if (auto sharding = llvm::dyn_cast_or_null(def)) + if (auto sharding = llvm::dyn_cast_or_null(def)) return sharding.get_XlaSharding(); if (auto sharding = def->getAttrOfType("_XlaSharding")) { @@ -456,20 +473,20 @@ std::optional GetXlaShardingFromRetval( continue; } - if (auto call_op = llvm::dyn_cast_or_null(def)) { - func::FuncOp func = - llvm::dyn_cast(call_op.resolveCallable()); + if (auto call_op = llvm::dyn_cast_or_null(def)) { + mlir::func::FuncOp func = + llvm::dyn_cast(call_op.resolveCallable()); if (!func) continue; value_to_visit = func.front().getTerminator()->getOperand( - value_to_visit.cast().getResultNumber()); + mlir::cast(value_to_visit).getResultNumber()); values_to_visit.push_back(value_to_visit); continue; } - if (auto while_op = llvm::dyn_cast(def)) { - if (auto op_result = value_to_visit.cast()) { + if (auto while_op = llvm::dyn_cast(def)) { + if (auto op_result = mlir::cast(value_to_visit)) { int result_idx = op_result.getResultNumber(); - if (auto yield_op = llvm::dyn_cast( + if (auto yield_op = llvm::dyn_cast( while_op.getBody().front().getTerminator())) { values_to_visit.push_back(yield_op.getOperand(result_idx)); } @@ -485,8 +502,8 @@ std::optional GetXlaShardingFromRetval( // XlaSharding/ TPUPartitionedOutput op connected to the retvals/results. void IdentifyXlaShardingForComputationOutputs( const llvm::SmallVector& logical_device_vec, - bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, - func::FuncOp func, Builder* builder, + bool infer_from_computation, mlir::tf_device::ClusterFuncOp cluster_func, + mlir::func::FuncOp func, Builder* builder, OptionalOpShardingVector& sharding_for_rets) { Block& function_block = func.front(); Operation* terminator = function_block.getTerminator(); @@ -566,8 +583,8 @@ absl::Status MoveSharding(OptionalOpShardingVector& optional_shardings, // depending on `use_spmd`. absl::Status IdentifyXlaShardingForInputsAndOutputs( const llvm::SmallVector& logical_device_vec, bool use_spmd, - bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, - func::FuncOp func, Builder* builder, OpShardingVector& input_sharding, + bool infer_from_computation, mlir::tf_device::ClusterFuncOp cluster_func, + mlir::func::FuncOp func, Builder* builder, OpShardingVector& input_sharding, OpShardingVector& output_sharding) { OptionalOpShardingVector optional_input_sharding; OptionalOpShardingVector optional_output_sharding; @@ -592,11 +609,11 @@ absl::Status IdentifyXlaShardingForInputsAndOutputs( // Extracts input/output sharding configuration of `cluster_func` by parsing // XlaSharding ops inside the `cluster_func`. 
LogicalResult IdentifyXlaShardingForTPUComputation( - Builder* builder, tf_device::ClusterFuncOp cluster_func) { + Builder* builder, mlir::tf_device::ClusterFuncOp cluster_func) { // Look up function definition from module. - func::FuncOp func = - cluster_func->getParentOfType().lookupSymbol( - cluster_func.getFunc()); + mlir::func::FuncOp func = + cluster_func->getParentOfType() + .lookupSymbol(cluster_func.getFunc()); bool use_spmd = false; if (auto use_spmd_attr = cluster_func->getAttrOfType(kUseSpmdAttr)) @@ -624,7 +641,7 @@ LogicalResult IdentifyXlaShardingForTPUComputation( sharding_for_args, sharding_for_rets); !status.ok()) { LOG(ERROR) << status; - return failure(); + return mlir::failure(); }; auto has_maximal_sharding = @@ -654,7 +671,7 @@ LogicalResult IdentifyXlaShardingForTPUComputation( sharding_for_args, sharding_for_rets); !status.ok()) { LOG(ERROR) << status; - return failure(); + return mlir::failure(); } } @@ -686,26 +703,30 @@ LogicalResult IdentifyXlaShardingForTPUComputation( GetStrArrayAttr(builder, sharding_for_args)); cluster_func->setAttr(tensorflow::kOutputShardingAttr, GetStrArrayAttr(builder, sharding_for_rets)); - return success(); + return mlir::success(); } void TPUShardingIdentificationPass::runOnOperation() { Builder builder(getOperation().getContext()); - auto result = getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) { - if (failed(IdentifyXlaShardingForTPUComputation(&builder, cluster_func))) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); + auto result = + getOperation().walk([&](mlir::tf_device::ClusterFuncOp cluster_func) { + if (failed( + IdentifyXlaShardingForTPUComputation(&builder, cluster_func))) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); if (result.wasInterrupted()) return signalPassFailure(); } } // namespace -std::unique_ptr> CreateTPUShardingIdentificationPass() { +std::unique_ptr> +CreateTPUShardingIdentificationPass() { return std::make_unique(); } -} // namespace TFTPU -} // namespace mlir +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_runtime_pipeline.mlir b/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_runtime_pipeline.mlir deleted file mode 100644 index 8b0dc1b54bf9e5..00000000000000 --- a/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_runtime_pipeline.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: tf-opt -split-input-file -hlo-xla-runtime-pipeline %s | FileCheck %s - -// CHECK-LABEL: func.func @simple_add( -func.func @simple_add(%arg0: tensor) -> tensor { - // CHECK: arith.addf - %0 = mhlo.add %arg0, %arg0 : tensor - return %0 : tensor -} diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir index 328a00ce59bbec..b015011dae7b2d 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir @@ -407,9 +407,9 @@ func.func @uniform_quantized_clip_by_value(%input: tensor<3x2xf32>) -> tensor<3x %zps = "tf.Const"() { value = dense<4> : tensor<2xi32> } : () -> tensor<2xi32> // tensor_proto that points to dense<127> of type !tf_type.qint32. 
- // CHECK-DAG: %[[MIN_MAX:.*]] = mhlo.constant() <{value = dense<127> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> - %min = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> - %max = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + // CHECK-DAG: %[[MIN_MAX:.*]] = mhlo.constant() <{value = dense<127> : tensor<3x2xi32>}> : () -> tensor<3x2x!quant.uniform> + %min = "tf.Const"() { value = #tf_type : tensor<3x2x!tf_type.qint32> } : () -> tensor<3x2x!tf_type.qint32> + %max = "tf.Const"() { value = #tf_type : tensor<3x2x!tf_type.qint32> } : () -> tensor<3x2x!tf_type.qint32> // CHECK-DAG: %[[OPERAND:.*]] = mhlo.uniform_quantize %arg0 : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> %0 = "tf.UniformQuantize"(%input, %scales, %zps) { @@ -419,10 +419,10 @@ func.func @uniform_quantized_clip_by_value(%input: tensor<3x2xf32>) -> tensor<3x // CHECK-DAG: %[[CONVERT_1:.*]] = mhlo.bitcast_convert %[[OPERAND]] : (tensor<3x2x!quant.uniform>) -> tensor<3x2xi32> // CHECK-DAG: %[[CONVERT_2:.*]] = mhlo.bitcast_convert %[[CONVERT_1]] : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> // CHECK: %[[MIN_CLIPPED:.*]] = chlo.broadcast_maximum %[[CONVERT_2]], %[[MIN_MAX]] {broadcast_dimensions = array} : - // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<3x2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> // CHECK: %[[MAX_CLIPPED:.*]] = chlo.broadcast_minimum %[[MIN_CLIPPED]], %[[MIN_MAX]] {broadcast_dimensions = array} : - // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<3x2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> // CHECK: %[[RESULT:.*]] = mhlo.bitcast_convert %[[MAX_CLIPPED]] : (tensor<3x2x!quant.uniform>) -> tensor<3x2xi32> // CHECK: return %[[RESULT]] : tensor<3x2xi32> @@ -430,48 +430,48 @@ func.func @uniform_quantized_clip_by_value(%input: tensor<3x2xf32>) -> tensor<3x quantization_axis = 1 : i64, quantization_min_val = -2147483648 : i64, quantization_max_val = 2147483647 : i64 - } : (tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> + } : (tensor<3x2x!tf_type.qint32>, tensor<3x2x!tf_type.qint32>, tensor<3x2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> func.return %1 : tensor<3x2x!tf_type.qint32> } // ----- // CHECK-LABEL: func @uniform_quantized_clip_by_value_min_not_const -func.func @uniform_quantized_clip_by_value_min_not_const(%input: tensor<3x2x!tf_type.qint32>, %min: tensor<2x!tf_type.qint32>) -> tensor<3x2x!tf_type.qint32> { +func.func @uniform_quantized_clip_by_value_min_not_const(%input: tensor<3x2x!tf_type.qint32>, %min: tensor<3x2x!tf_type.qint32>) -> tensor<3x2x!tf_type.qint32> { %scales = "tf.Const"() { value = dense<2.0> : tensor<2xf32> } : () -> tensor<2xf32> %zps = "tf.Const"() { value = dense<4> : tensor<2xi32> } : () -> tensor<2xi32> // tensor_proto that points to dense<127> of type !tf_type.qint32. 
- %max = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %max = "tf.Const"() { value = #tf_type : tensor<3x2x!tf_type.qint32> } : () -> tensor<3x2x!tf_type.qint32> // CHECK-DAG: %[[INPUT:.*]] = mhlo.bitcast_convert %arg0 : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> - // CHECK-DAG: %[[MIN:.*]] = mhlo.bitcast_convert %arg1 : (tensor<2xi32>) -> tensor<2x!quant.uniform> + // CHECK-DAG: %[[MIN:.*]] = mhlo.bitcast_convert %arg1 : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> // CHECK: chlo.broadcast_maximum %[[INPUT]], %[[MIN]] %res = "tf.UniformQuantizedClipByValue"(%input, %min, %max, %scales, %zps) { quantization_axis = 1 : i64, quantization_min_val = -2147483648 : i64, quantization_max_val = 2147483647 : i64 - } : (tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> + } : (tensor<3x2x!tf_type.qint32>, tensor<3x2x!tf_type.qint32>, tensor<3x2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> func.return %res : tensor<3x2x!tf_type.qint32> } // ----- // CHECK-LABEL: func @uniform_quantized_clip_by_value_max_not_const -func.func @uniform_quantized_clip_by_value_max_not_const(%input: tensor<3x2x!tf_type.qint32>, %max: tensor<2x!tf_type.qint32>) -> tensor<3x2x!tf_type.qint32> { +func.func @uniform_quantized_clip_by_value_max_not_const(%input: tensor<3x2x!tf_type.qint32>, %max: tensor<3x2x!tf_type.qint32>) -> tensor<3x2x!tf_type.qint32> { %scales = "tf.Const"() { value = dense<2.0> : tensor<2xf32> } : () -> tensor<2xf32> %zps = "tf.Const"() { value = dense<4> : tensor<2xi32> } : () -> tensor<2xi32> // tensor_proto that points to dense<127> of type !tf_type.qint32. - %min = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %min = "tf.Const"() { value = #tf_type : tensor<3x2x!tf_type.qint32> } : () -> tensor<3x2x!tf_type.qint32> // CHECK-DAG: %[[INPUT:.*]] = mhlo.bitcast_convert %arg0 : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> - // CHECK-DAG: %[[MAX:.*]] = mhlo.bitcast_convert %arg1 : (tensor<2xi32>) -> tensor<2x!quant.uniform> + // CHECK-DAG: %[[MAX:.*]] = mhlo.bitcast_convert %arg1 : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> // CHECK-DAG: %[[INPUT_1:.*]] = chlo.broadcast_maximum // CHECK: chlo.broadcast_minimum %[[INPUT_1]], %[[MAX]] %res = "tf.UniformQuantizedClipByValue"(%input, %min, %max, %scales, %zps) { quantization_axis = 1 : i64, quantization_min_val = -2147483648 : i64, quantization_max_val = 2147483647 : i64 - } : (tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> + } : (tensor<3x2x!tf_type.qint32>, tensor<3x2x!tf_type.qint32>, tensor<3x2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> func.return %res : tensor<3x2x!tf_type.qint32> } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index bb9ca266fc7abc..91008b91056d40 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -2637,35 +2637,6 @@ func.func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2x2xf16> // ----- -// CHECK-LABEL: reshape -func.func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> { - // CHECK: mhlo.reshape - %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> 
tensor<2x1xf32> - func.return %0 : tensor<2x1xf32> -} - -// ----- - -// CHECK-LABEL: not_lowering_reshape -func.func @not_lowering_reshape(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor<1x!tf_type.string> { - // CHECK: "tf.Reshape" - %0 = "tf.Reshape"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor<1x!tf_type.string> - func.return %0 : tensor<1x!tf_type.string> -} - -// ----- - -// CHECK-LABEL: reshape_dynamic -func.func @reshape_dynamic(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor { - // CHECK: "chlo.dynamic_reshape" - // CHLO: mhlo.compute_reshape_shape - // CHLO: mhlo.dynamic_reshape - %0 = "tf.Reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor - func.return %0 : tensor -} - -// ----- - // CHECK-LABEL: squeeze func.func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> { // CHECK: mhlo.reshape @@ -2680,7 +2651,7 @@ func.func @squeeze_ranked(%arg0: tensor) -> tensor { // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor // CHECK: %[[T:.*]] = tensor.from_elements %[[D2]] : tensor<1xindex> - // CHECK: %[[R:.*]] = "chlo.dynamic_reshape"(%arg0, %[[T]]) : (tensor, tensor<1xindex>) -> tensor + // CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor, tensor<1xindex>) -> tensor // CHECK: return %[[R]] : tensor %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [0, 1] }: (tensor) -> tensor func.return %0 : tensor @@ -2695,7 +2666,7 @@ func.func @squeeze_ranked_negative(%arg0: tensor) -> tensor // CHECK: %[[T:.*]] = tensor.from_elements %[[D0]], %[[D2]] : tensor<2xindex> - // CHECK: %[[R:.*]] = "chlo.dynamic_reshape"(%arg0, %[[T]]) : (tensor, tensor<2xindex>) -> tensor + // CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor, tensor<2xindex>) -> tensor // CHECK: return %[[R]] : tensor %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [-2] }: (tensor) -> tensor func.return %0 : tensor diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index b76b52c9fd774a..34ffdfa90f028f 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -477,6 +477,7 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tpu_embedding_ops_registry", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc index b5a99b35f7b547..d5e5c5d08e4ff3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h" +#include "llvm/ADT/DenseSet.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" @@ -135,232 +136,243 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { // end, which would not be thread safe. 
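The comment ending above explains why this allowlist is built once rather than appended to later: concurrent mutation of the set would race. Reduced to its skeleton, the pattern is a function-local static holding a never-freed heap set, which C++11 guarantees is initialized exactly once (names here are illustrative):

```c++
#include "llvm/ADT/DenseSet.h"
#include "mlir/Support/TypeID.h"

// Skeleton of the allowlist pattern used below: initialized once, thread-safe
// under C++11 magic statics, and intentionally leaked.
const llvm::SmallDenseSet<mlir::TypeID, 512>& AllowedOps() {
  static auto* ops = new llvm::SmallDenseSet<mlir::TypeID, 512>{
      // mlir::TypeID::get<SomeOp>(), ...
  };
  return *ops;
}
```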
static auto* ops = [] { - llvm::SmallDenseSet* ops_set = - new llvm::SmallDenseSet{ - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - // CaseOp isn't actually supported but is enabled for testing to - // make sure ops with symbol ref attributes are filtered out. - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - // TODO(hinsu): Canonicalize QuantizeAndDequantize and - // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting - // attributes to operands. 
- TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get< - TF::XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInputOp>(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - }; + llvm::SmallDenseSet* ops_set = new llvm::SmallDenseSet< + mlir::TypeID, 512>{ + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + // CaseOp isn't actually supported but is enabled for testing to + // make sure ops with symbol ref attributes are filtered out. 
+ TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + // TODO(hinsu): Canonicalize QuantizeAndDequantize and + // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting + // attributes to operands. + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get< + TF::XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInputOp>(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get< + TF::XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSizeOp>(), + TypeID::get< + 
TF::XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSizeOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulGradWithAdamAndStaticBufferSizeOp>(), + TypeID::get< + TF::XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSizeOp>(), + TypeID::get< + TF::XlaSparseDenseMatmulGradWithSgdAndStaticBufferSizeOp>(), // NOLINT + TypeID::get(), + TypeID::get(), + TypeID::get(), + }; // Add the ops from the TPUEmbeddingOpsRegistry. for (auto op_type_id : diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index bd8b2135882bb2..25b8196ebfe629 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -54,12 +54,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr class LegalizationOpConfigTest : public ::testing::Test { public: - tsl::Status CreateMlirModule(std::string module_string = kMlirModuleStr) { + absl::Status CreateMlirModule(std::string module_string = kMlirModuleStr) { TF_ASSIGN_OR_RETURN( module_, test::GetMlirModuleFromString(module_string, &context_)); context_.loadAllAvailableDialects(); - return tsl::OkStatus(); + return absl::OkStatus(); } absl::StatusOr GetMain() { @@ -135,8 +135,8 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // from MLIR to TF2XLA), these numbers should change. Or if TF Dialect adds // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 67); - EXPECT_EQ(tf2xla_fallback_count, 316); - EXPECT_EQ(non_categorized_count, 424); + EXPECT_EQ(tf2xla_fallback_count, 322); + EXPECT_EQ(non_categorized_count, 428); } // Just a counter test to see which ops have duplicate lowerings. This isn't a diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index edf0b96b569fea..fd0b33c20c7127 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -51,6 +51,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -89,10 +90,10 @@ static size_t GetFeatureDimension(tensorflow::TensorFormat format, // Gets all integer values from the given attribute and push them to `values`. void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl *values) { - auto array_attr = attr.cast(); + auto array_attr = mlir::cast(attr); values->reserve(array_attr.getValue().size()); for (Attribute val : array_attr.getValue()) - values->push_back(val.cast().getValue().getSExtValue()); + values->push_back(mlir::cast(val).getValue().getSExtValue()); } // Returns 1D 32-bit dense elements attribute with the given values. @@ -142,8 +143,8 @@ Type GetSumAccumulationType(Type input_type) { // format supports negative indexing unlike HLO. 
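Nearly every hunk in legalize_tf.cc from here on applies one mechanical change: member-style casts on mlir::Type and mlir::Attribute (value.getType().cast<T>(), attr.dyn_cast<T>()) become the free functions mlir::cast<T>(...), mlir::dyn_cast<T>(...), and mlir::dyn_cast_or_null<T>(...), made available through the newly included mlir/Support/LLVM.h. A compilable sketch of the before/after, with an illustrative variable name:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

// The free-function casts work uniformly for Type and Attribute and return
// a null value on failure, exactly like the old member-style dyn_cast.
mlir::RankedTensorType GetRankedType(mlir::Value v) {
  // Before: v.getType().dyn_cast<mlir::RankedTensorType>()
  return mlir::dyn_cast<mlir::RankedTensorType>(v.getType());
}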
static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, Builder *b) { - IntegerAttr intAttr = attr.dyn_cast_or_null(); - if (auto elementAttr = attr.dyn_cast_or_null()) { + IntegerAttr intAttr = mlir::dyn_cast_or_null(attr); + if (auto elementAttr = mlir::dyn_cast_or_null(attr)) { SmallVector index(elementAttr.getShapedType().getRank(), 0); intAttr = elementAttr.getValues()[index]; } @@ -198,7 +199,7 @@ static ConvertOp CastValueToI64(Location loc, Value value, // must be a ranked tensor. static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value, PatternRewriter *rewriter) { - auto indices_type = value.getType().cast(); + auto indices_type = mlir::cast(value.getType()); int num_outputs = indices_type.getShape().front(); SmallVector unpacked_indices_type( num_outputs, @@ -214,7 +215,7 @@ static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value, // // Aborts if the type is ranked but doesn't have the dimension. int64_t GetDimSize(Type ty, int64_t index) { - RankedTensorType ranked_ty = ty.dyn_cast(); + RankedTensorType ranked_ty = mlir::dyn_cast(ty); if (!ranked_ty) return -1; return ranked_ty.getDimSize(index); @@ -298,8 +299,8 @@ template static Value StaticBinaryBroadcast(Location loc, Value x, Value y, DenseIntElementsAttr broadcast_dims, OpBuilder &builder) { - auto x_type = x.getType().cast(); - auto y_type = y.getType().cast(); + auto x_type = mlir::cast(x.getType()); + auto y_type = mlir::cast(y.getType()); auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims); if (!result_type) { emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type @@ -353,7 +354,7 @@ static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, Value broadcast_from, int64_t feature_dim, OpBuilder &builder) { auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); - auto to_type = broadcast_to.getType().cast(); + auto to_type = mlir::cast(broadcast_to.getType()); auto result_shape = builder.create(loc, broadcast_to); auto result_extents_type = GetExtentsTensorTypeFor(to_type); auto result_extents = builder.create( @@ -372,11 +373,11 @@ static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, OpBuilder &builder) { auto result_shape = builder.create(loc, broadcast_to); - auto to_type = broadcast_to.getType().cast(); + auto to_type = mlir::cast(broadcast_to.getType()); auto result_extents_type = GetExtentsTensorTypeFor(to_type); auto result_extents = builder.create( loc, result_extents_type, result_shape); - int64_t rank = input.getType().cast().getRank(); + int64_t rank = mlir::cast(input.getType()).getRank(); auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder); return builder.create( loc, to_type, input, result_extents, broadcast_dims); @@ -520,8 +521,8 @@ static void CreateWhile32(Location loc, int num_iterations, static IntegerAttr getFeatureDimensionAttr(Builder &b, tensorflow::TensorFormat format, Value input) { - return b.getI64IntegerAttr( - GetFeatureDimension(format, input.getType().cast())); + return b.getI64IntegerAttr(GetFeatureDimension( + format, mlir::cast(input.getType()))); } //===----------------------------------------------------------------------===// @@ -567,7 +568,7 @@ static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) { // attribute. 
static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( ElementsAttr input, int column) { - auto int_attr = input.cast(); + auto int_attr = mlir::cast(input); auto shaped_type = int_attr.getType(); auto shape = shaped_type.getShape(); @@ -605,8 +606,8 @@ static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { // must be broadcasted with a size 1 tensor or another dynamic dimension. // Returns false on rankless. static bool AreBroadcastCompatible(Value x, Value y) { - auto x_rankless = x.getType().dyn_cast(); - auto y_rankless = y.getType().dyn_cast(); + auto x_rankless = mlir::dyn_cast(x.getType()); + auto y_rankless = mlir::dyn_cast(y.getType()); if (!x_rankless || !y_rankless) { return false; } @@ -634,7 +635,7 @@ static bool AreBroadcastCompatible(Value x, Value y) { // updated element type. static Type ChangeTensorElementType(Builder *b, Type tensor_type, Type element_type) { - RankedTensorType ranked_type = tensor_type.dyn_cast(); + RankedTensorType ranked_type = mlir::dyn_cast(tensor_type); if (ranked_type) { return tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), element_type); @@ -659,7 +660,7 @@ static Type GetAccumulationType(Type ty) { //===----------------------------------------------------------------------===// static DenseElementsAttr GetEpsilonValue(Type ty) { - auto element_ty = ty.cast().getElementType(); + auto element_ty = mlir::cast(ty).getElementType(); auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, element_ty); if (element_ty.isF16()) { uint16_t raw_epsilon = Eigen::numext::bit_cast( @@ -750,9 +751,10 @@ static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args, static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices, DenseIntElementsAttr slice_sizes) { - auto input_ty = input.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); if (!input_ty) return false; - auto start_indices_ty = start_indices.getType().dyn_cast(); + auto start_indices_ty = + mlir::dyn_cast(start_indices.getType()); if (!start_indices_ty) return false; int64_t input_rank = input_ty.getRank(); @@ -780,11 +782,11 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( Builder *builder) { DenseIntElementsAttr constant_start_indices; if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { - return hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64)) - .cast(); + return mlir::cast( + hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64))); } - auto input_ty = input.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); int64_t input_rank = input_ty.getRank(); ArrayRef input_shape = input_ty.getShape(); SmallVector normalized_sizes; @@ -906,7 +908,7 @@ class ConvertBiasAddOp : public OpRewritePattern { if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); - auto value_type = op.getValue().getType().dyn_cast(); + auto value_type = mlir::dyn_cast(op.getValue().getType()); if (!value_type) return failure(); auto feature_dim = GetFeatureDimension(data_format, value_type); auto bias_broadcast = Broadcast1DToFeatureDim( @@ -1008,11 +1010,9 @@ class ConvertConvDynamic : public OpRewritePattern { if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); - auto input_ty = - op.getInput().getType().template dyn_cast(); - auto filter_ty = - op.getFilter().getType().template dyn_cast(); - auto result_ty = op.getType().template dyn_cast(); + auto input_ty = 
mlir::dyn_cast(op.getInput().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); + auto result_ty = mlir::dyn_cast(op.getType()); if (!input_ty || !filter_ty || !result_ty) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization // implemented. @@ -1035,7 +1035,7 @@ class ConvertConvDynamic : public OpRewritePattern { SmallVector paddings; auto get_int = [](Attribute attr) { - return attr.template cast().getInt(); + return mlir::cast(attr).getInt(); }; constexpr int num_dims = num_spatial_dims + 2; @@ -1177,10 +1177,8 @@ class ConvertConvOp : public OpRewritePattern { if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); - auto input_ty = - op.getInput().getType().template dyn_cast(); - auto filter_ty = - op.getFilter().getType().template dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); // With the exception of input's batch dimension, input and filter need to // have static shape for calculation of HLO paddings and feature group count @@ -1205,7 +1203,7 @@ class ConvertConvOp : public OpRewritePattern { SmallVector paddings; auto get_int = [](Attribute attr) { - return attr.template cast().getInt(); + return mlir::cast(attr).getInt(); }; constexpr int num_dims = num_spatial_dims + 2; @@ -1228,7 +1226,7 @@ class ConvertConvOp : public OpRewritePattern { int64_t pad_high_int64; int64_t input_size = input_ty.getDimSize(dim); if (input_size == ShapedType::kDynamic) return failure(); - tsl::Status status = tensorflow::GetWindowedOutputSizeVerbose( + absl::Status status = tensorflow::GetWindowedOutputSizeVerbose( input_size, filter_ty.getDimSize(i), dilation, stride, padding, &output_size, &pad_low_int64, &pad_high_int64); if (!status.ok()) return failure(); @@ -1318,8 +1316,8 @@ class ConvertPadOpDynamic : public OpRewritePattern { auto input = op.getInput(); auto paddings = op.getPaddings(); auto constant_values = op.getConstantValues(); - auto input_type = input.getType().dyn_cast(); - auto paddings_type = paddings.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); + auto paddings_type = mlir::dyn_cast(paddings.getType()); if (!input_type || !paddings_type || !paddings_type.hasStaticShape()) return failure(); @@ -1385,9 +1383,9 @@ class ConvertGatherNdOpDynamic : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto params = op.getParams(); - auto params_ty = params.getType().dyn_cast(); + auto params_ty = mlir::dyn_cast(params.getType()); auto indices = op.getIndices(); - auto indices_ty = indices.getType().dyn_cast(); + auto indices_ty = mlir::dyn_cast(indices.getType()); auto params_rank = params_ty.getRank(); auto indices_rank = indices_ty.getRank(); int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1); @@ -1485,8 +1483,8 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::FloorDivOp op, PatternRewriter &rewriter) const override { - auto l = op.getX().dyn_cast>(); - auto r = op.getY().dyn_cast>(); + auto l = mlir::dyn_cast>(op.getX()); + auto r = mlir::dyn_cast>(op.getY()); if (!l || !r) return failure(); auto element_type = getElementTypeOrSelf(l.getType()); @@ -1515,14 +1513,14 @@ class ConvertBroadcastToOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::BroadcastToOp op, PatternRewriter &rewriter) const override { - auto input_type = op.getInput().getType().dyn_cast(); + auto input_type 
= mlir::dyn_cast(op.getInput().getType()); auto output_type = op.getOutput().getType(); if (!input_type) { return rewriter.notifyMatchFailure(op, "requires ranked input shape"); } llvm::SmallVector broadcast_dimensions; if (input_type.getRank() > 0) { - auto ranked_output_type = output_type.dyn_cast(); + auto ranked_output_type = mlir::dyn_cast(output_type); if (!ranked_output_type) { return rewriter.notifyMatchFailure(op, "requires ranked output shape"); } @@ -1546,7 +1544,7 @@ class ConvertRollOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TF::RollOp op, PatternRewriter &rewriter) const override { - auto shift_ty = op.getShift().getType().dyn_cast(); + auto shift_ty = mlir::dyn_cast(op.getShift().getType()); if (!shift_ty || shift_ty.getRank() != 0) { return rewriter.notifyMatchFailure( op, "require the type of shift to be 0D tensor"); @@ -1558,7 +1556,7 @@ class ConvertRollOp : public OpRewritePattern { } int axis = val.getSExtValue(); - auto input_ty = op.getInput().getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty || !input_ty.hasStaticShape()) { return rewriter.notifyMatchFailure( op, "require the type of input to have static shapes"); @@ -1674,7 +1672,7 @@ class ConvertDiagPartOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::DiagPartOp op, PatternRewriter &rewriter) const override { - auto input_type = op.getInput().getType().dyn_cast(); + auto input_type = mlir::dyn_cast(op.getInput().getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); int64_t num_dims = input_type.getRank(); if (num_dims < 2 || num_dims % 2 != 0) return failure(); @@ -1771,7 +1769,7 @@ class ConvertMatrixDiagPartV3Op LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - ShapedType input_type = op.getInput().getType().dyn_cast(); + ShapedType input_type = mlir::dyn_cast(op.getInput().getType()); // Align is a string specifying how superdiagonals and subdiagonals should // be aligned/padded for diagonals that are shorter than max_diag_len. The @@ -2035,7 +2033,7 @@ class ConvertFFTOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto input_ty = op.getInput().getType().template cast(); + auto input_ty = mlir::cast(op.getInput().getType()); if (!input_ty.hasRank()) { return failure(); } @@ -2131,14 +2129,12 @@ class ConvertFusedBatchNormGradBase // TODO(b/141785544): Update this to not require static shapes. // activation shape needs to be static to convert negative indices in // TensorFlow to absolute indices required by HLO. - RankedTensorType act_type = - act.getType().template dyn_cast(); + RankedTensorType act_type = mlir::dyn_cast(act.getType()); if (!act_type) return failure(); Type act_ele_type = act_type.getElementType(); // To support mixed precision, the statistics type, which maybe more // precise than the input types, are used for this op. 
- Type kernel_type = - scale.getType().template cast().getElementType(); + Type kernel_type = mlir::cast(scale.getType()).getElementType(); grad = rewriter.create(loc, grad, kernel_type); act = rewriter.create(loc, act, kernel_type); @@ -2260,14 +2256,13 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { auto feature_dim = getFeatureDimensionAttr(rewriter, data_format, op.getX()); - auto input_type_tensor = op.getX().getType().template cast(); + auto input_type_tensor = mlir::cast(op.getX().getType()); auto input_element_type = input_type_tensor.getElementType(); - auto scale_type_tensor = - op.getScale().getType().template cast(); + auto scale_type_tensor = mlir::cast(op.getScale().getType()); auto scale_element_type = scale_type_tensor.getElementType(); - auto mean_type_tensor = op.getMean().getType().template cast(); + auto mean_type_tensor = mlir::cast(op.getMean().getType()); auto mean_element_type = mean_type_tensor.getElementType(); // In the training case, dimensions of input tensors must be static. if (op.getIsTraining() && (!input_type_tensor.hasStaticShape() || @@ -2281,7 +2276,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { Value bn_train_input = rewriter.create( op.getLoc(), op.getX(), scale_element_type); TensorType bn_train_input_type_tensor = - bn_train_input.getType().template cast(); + mlir::cast(bn_train_input.getType()); if (op.getIsTraining()) { // Training case. @@ -2372,7 +2367,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // For FusedBatchNormV3Op, also create a constant tensor to forward to // last reserve_space_3 output. auto reserve_space_3_type = - op.getResult(5).getType().template cast(); + mlir::cast(op.getResult(5).getType()); int num_elements = reserve_space_3_type.hasStaticShape() ? reserve_space_3_type.getNumElements() : 0; @@ -2416,7 +2411,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // For FusedBatchNormV3Op, also create a constant tensor to forward to // last reserve_space_3 output. auto reserve_space_3_type = - op.getResult(5).getType().template cast(); + mlir::cast(op.getResult(5).getType()); int num_elements = reserve_space_3_type.hasStaticShape() ? 
reserve_space_3_type.getNumElements() : 0; @@ -2465,9 +2460,9 @@ static PaddingArray GetReduceWindowPaddingAsArray( for (const auto &dim : input_dims) input_shape.push_back(dim); for (Attribute attr : window_dims) - window_shape.push_back(attr.cast().getInt()); + window_shape.push_back(mlir::cast(attr).getInt()); for (Attribute attr : window_strides) - strides.push_back(attr.cast().getInt()); + strides.push_back(mlir::cast(attr).getInt()); PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides, ::xla::Padding::kSame); @@ -2509,8 +2504,7 @@ Operation *AvgPoolDivideByCount( const SmallVector &strides, OpTy op, Value zero, PatternRewriter &rewriter) { Location loc = op.getLoc(); - RankedTensorType pooled_type = - pooled.getType().template cast(); + RankedTensorType pooled_type = mlir::cast(pooled.getType()); Type element_type = pooled_type.getElementType(); Operation *result = nullptr; RankedTensorType orig_input_type = @@ -2577,8 +2571,7 @@ class ConvertAvgPoolOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Value input_value = GetAvgPoolInput(op); - auto input_type = - input_value.getType().template dyn_cast(); + auto input_type = mlir::dyn_cast(input_value.getType()); if (!input_type) return failure(); // We will do accumulation first; use a larger bitwidth if suitable. @@ -2587,7 +2580,7 @@ class ConvertAvgPoolOp : public OpRewritePattern { Type result_type; // The result type for reduction and division with the proper element type. - if (auto ranked_type = op.getType().template dyn_cast()) + if (auto ranked_type = mlir::dyn_cast(op.getType())) result_type = tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), sum_element_type); else @@ -2695,8 +2688,7 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { // `out_grad` is the gradient that was propagated via backpropagation from // the output layer. Value out_grad = op.getGrad(); - auto out_grad_type = - out_grad.getType().template dyn_cast(); + auto out_grad_type = mlir::dyn_cast(out_grad.getType()); if (!out_grad_type) { return failure(); } @@ -2833,7 +2825,7 @@ class ConvertMaxPoolOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Type element_type = - op.getInput().getType().template cast().getElementType(); + mlir::cast(op.getInput().getType()).getElementType(); if (!element_type.isSignlessIntOrFloat()) return failure(); tensorflow::Padding padding; if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) @@ -2845,8 +2837,7 @@ class ConvertMaxPoolOp : public OpRewritePattern { ConstantOp init = GetScalarLimitConstOfType( element_type, loc, hlo::kInfinityLowest, &rewriter); - auto input_ty = - op.getInput().getType().template dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), @@ -2875,9 +2866,12 @@ class ConvertSelectOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SelectOp op, PatternRewriter &rewriter) const override { // This lowering only works on ranked types. 
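The guard that follows is the idiom used by most patterns in this file: dyn_cast to RankedTensorType yields null for unranked values, and the pattern reports a readable reason via notifyMatchFailure instead of failing silently. A sketch under assumed names (the helper itself is hypothetical):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

// Bail out of a rewrite pattern when `v` is not a ranked tensor; the message
// is recorded by the pattern driver and helps diagnose why a lowering did
// not fire.
mlir::LogicalResult RequireRanked(mlir::Operation *op, mlir::Value v,
                                  mlir::PatternRewriter &rewriter) {
  if (!mlir::dyn_cast<mlir::RankedTensorType>(v.getType()))
    return rewriter.notifyMatchFailure(op, "requires ranked tensor operand");
  return mlir::success();
}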
- auto cond_type = op.getCondition().getType().dyn_cast(); - auto then_type = op.getThenValue().getType().dyn_cast(); - auto else_type = op.getElseValue().getType().dyn_cast(); + auto cond_type = + mlir::dyn_cast(op.getCondition().getType()); + auto then_type = + mlir::dyn_cast(op.getThenValue().getType()); + auto else_type = + mlir::dyn_cast(op.getElseValue().getType()); if (!cond_type || !then_type || !else_type) { return failure(); } @@ -2913,7 +2907,7 @@ class ConvertSelectOp : public OpRewritePattern { assumption = b.createOrFold( witness, ValueRange{assumption, eq_cstr}); } - auto result_type = op.getResult().getType().cast(); + auto result_type = mlir::cast(op.getResult().getType()); auto assuming_op = b.create(ArrayRef{result_type}, assumption); @@ -2978,7 +2972,7 @@ class ConvertSigmoidOp : public RewritePattern { // Create constant half with shape and element type same as the operand. Value operand = op.getOperand(); - auto operand_ty = operand.getType().cast(); + auto operand_ty = mlir::cast(operand.getType()); auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, operand_ty.getElementType()); ElementsAttr attr = mlir::hlo::getSplat(&rewriter, scalar_ty, 0.5); @@ -3009,9 +3003,9 @@ class ConvertSliceOpDynamic : public OpRewritePattern { Value begin_indices = op.getBegin(); Value sizes = op.getSize(); - auto input_ty = input.getType().dyn_cast(); - auto begin_type = begin_indices.getType().dyn_cast(); - auto size_type = sizes.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); + auto begin_type = mlir::dyn_cast(begin_indices.getType()); + auto size_type = mlir::dyn_cast(sizes.getType()); if (!input_ty || !begin_type || !size_type || !begin_type.hasStaticShape() || !size_type.hasStaticShape() || @@ -3112,8 +3106,8 @@ static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2); auto rhs_splitted = rewriter->create( loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2); - auto lhs_type = lhs.getType().cast(); - auto rhs_type = rhs.getType().cast(); + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); // The last two dimensions are the matrix row/col dimensions. Don't broadcast // them. SmallVector result_batch_shape_compile_time_extents; @@ -3166,21 +3160,21 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { PatternRewriter &rewriter) const override { Value lhs = op.getX(); Value rhs = op.getY(); - auto lhs_type = lhs.getType().dyn_cast(); - auto rhs_type = rhs.getType().dyn_cast(); + auto lhs_type = mlir::dyn_cast(lhs.getType()); + auto rhs_type = mlir::dyn_cast(rhs.getType()); if (!lhs_type || !rhs_type) return failure(); - if (lhs_type.getElementType().isa() && op.getAdjX()) { + if (mlir::isa(lhs_type.getElementType()) && op.getAdjX()) { lhs = rewriter.create(op.getLoc(), lhs_type, lhs); } - if (rhs_type.getElementType().isa() && op.getAdjY()) { + if (mlir::isa(rhs_type.getElementType()) && op.getAdjY()) { rhs = rewriter.create(op.getLoc(), rhs_type, rhs); } // Broadcast both operands. 
BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs, &rewriter); - lhs_type = lhs.getType().cast(); - rhs_type = rhs.getType().cast(); + lhs_type = mlir::cast(lhs.getType()); + rhs_type = mlir::cast(rhs.getType()); assert(lhs_type.getRank() == rhs_type.getRank()); int64_t rank = lhs_type.getRank(); auto batch_dimensions = llvm::to_vector<4>(llvm::seq(0, rank - 2)); @@ -3243,7 +3237,7 @@ class ConvertSplitOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SplitOp op, PatternRewriter &rewriter) const override { // We can only split inputs that have fully static shape. - auto input_type = op.getValue().getType().dyn_cast(); + auto input_type = mlir::dyn_cast(op.getValue().getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); // We can only match when the split dimension is a constant scalar. @@ -3304,7 +3298,7 @@ class ConvertSplitOpDynamic : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getValue(); - auto input_type = input.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); if (!input_type) return failure(); // TODO(disc): remove static shape check once folding/canonicalization func @@ -3419,7 +3413,7 @@ class ConvertSplitVOp : public OpRewritePattern { PatternRewriter &rewriter) const override { // We can only split inputs that have fully static shape. // TODO(b/145731001): enhance to support dynamic-shaped inputs. - auto input_type = op.getValue().getType().dyn_cast(); + auto input_type = mlir::dyn_cast(op.getValue().getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); // We can only match when the split dimension is a constant scalar. @@ -3438,7 +3432,7 @@ class ConvertSplitVOp : public OpRewritePattern { int64_t total_dim_size = 0; // Total dimension size assigned to splits std::optional dynamic_dim_index; split_sizes.reserve( - split_sizes_attr.getType().cast().getNumElements()); + mlir::cast(split_sizes_attr.getType()).getNumElements()); for (const auto &dim : llvm::enumerate(split_sizes_attr)) { int64_t dim_val = dim.value().getSExtValue(); split_sizes.push_back(dim_val); @@ -3620,7 +3614,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // Begin must be a ranked, 1-dimensional tensor: This is checked by the // verifier. int64_t slicing_dim_size = - op.getBegin().getType().cast().getDimSize(0); + mlir::cast(op.getBegin().getType()).getDimSize(0); uint64_t begin_mask = op.getBeginMask(); uint64_t end_mask = op.getEndMask(); const int input_rank = input_shape.size(); @@ -3642,7 +3636,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // For the dimensions that are to be sliced, all have slice sizes of 1. SmallVector slice_sizes; auto begin_element_ty = - op.getBegin().getType().cast().getElementType(); + mlir::cast(op.getBegin().getType()).getElementType(); // Scalar tensor type. TensorType type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, begin_element_ty); @@ -3696,14 +3690,14 @@ class ConvertStridedSliceOp : public OpRewritePattern { // // TODO(hinsu): Relax this constraint for ops without negative indices and // strides. - auto input_ty = op.getInput().getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty || !input_ty.hasStaticShape()) return failure(); // Output shape needs to be static to apply 'new_axis_mask' or // 'shrink_axis_mask' by reshaping tensor after slice. // // TODO(hinsu): Relax this constraint for ops without the above masks. 
- auto result_ty = op.getType().dyn_cast(); + auto result_ty = mlir::dyn_cast(op.getType()); if (!result_ty || !result_ty.hasStaticShape()) return failure(); DenseIntElementsAttr sparse_begin_attr, sparse_end_attr; @@ -3750,7 +3744,7 @@ class ConvertStridedSliceGradOp return failure(); Value grad = op.getDy(); - Type element_type = grad.getType().cast().getElementType(); + Type element_type = mlir::cast(grad.getType()).getElementType(); // Perform reshape to undo any new/shrink axes done by strided slice. grad = rewriter.create( @@ -3830,7 +3824,7 @@ class ConvertRangeOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto result = op.getResult(); auto result_type = result.getType(); - if (!result_type.cast().hasStaticShape()) { + if (!mlir::cast(result_type).hasStaticShape()) { return failure(); } @@ -3863,7 +3857,7 @@ class ConvertDynamicRangeOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::RangeOp op, PatternRewriter &rewriter) const override { auto result = op.getResult(); - auto result_type = result.getType().cast(); + auto result_type = mlir::cast(result.getType()); if (result_type.hasStaticShape()) { return failure(); } @@ -3875,11 +3869,12 @@ class ConvertDynamicRangeOp : public OpRewritePattern { // To compute the length we need to use floating point calculations so that // ceil can be computed for the number of steps. auto compute_element_type = - getElementTypeOrSelf(start.getType()).isa() + mlir::isa(getElementTypeOrSelf(start.getType())) ? getElementTypeOrSelf(start.getType()) : rewriter.getF64Type(); auto compute_type = tensorflow::GetTypeFromTFTensorShape( - limit.getType().cast().getShape(), compute_element_type); + mlir::cast(limit.getType()).getShape(), + compute_element_type); // Compute the length of the sequence we are going to need. This includes // some conversion to float for the operations. @@ -3930,8 +3925,8 @@ class ConvertDynamicRangeOp : public OpRewritePattern { }; ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { - auto int_attr = attr.cast(); - auto type = val.getType().cast(); + auto int_attr = mlir::cast(attr); + auto type = mlir::cast(val.getType()); SmallVector axis; axis.reserve(int_attr.getNumElements()); @@ -3954,7 +3949,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::LinSpaceOp op, PatternRewriter &rewriter) const override { auto result = op.getResult(); - auto result_type = result.getType().dyn_cast(); + auto result_type = mlir::dyn_cast(result.getType()); if (!result_type || !result_type.hasStaticShape()) { return failure(); } @@ -4023,8 +4018,7 @@ class GenericConvertReductionOp : public OpRewritePattern { // TODO(b/141785544): Update this to not require ranked shapes. // Input shape needs to be ranked to convert negative indices in TensorFlow // to absolute indices required by HLO. - auto input_ty = - op.getInput().getType().template dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty) return failure(); ArrayRef input_shape = input_ty.getShape(); @@ -4049,8 +4043,9 @@ class GenericConvertReductionOp : public OpRewritePattern { Type element_type = input_ty.getElementType(); // Only float, int, and complex types are currently supported. 
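The element-type check below migrates to the free-function form as well. Since llvm::isa (and therefore mlir::isa) is variadic, the three tests can equivalently be collapsed into one call; a sketch, with the type list taken from the comment above and an illustrative helper name:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

// Variadic isa<> returns true if the value is any of the listed types.
bool IsSupportedReductionElementType(mlir::Type element_type) {
  return mlir::isa<mlir::FloatType, mlir::IntegerType, mlir::ComplexType>(
      element_type);
}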
- if (!element_type.isa() && !element_type.isa() && - !element_type.isa()) { + if (!mlir::isa(element_type) && + !mlir::isa(element_type) && + !mlir::isa(element_type)) { return rewriter.notifyMatchFailure( op, "element type must be float, int, or complex type"); } @@ -4252,7 +4247,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { RankedTensorType input_type = - op.getInput().getType().template dyn_cast(); + mlir::dyn_cast(op.getInput().getType()); if (!input_type) { return failure(); } @@ -4267,7 +4262,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { Derived::GetInitialValue(input_element_type, loc, rewriter); RankedTensorType output_type = - op.getOutput().getType().template dyn_cast(); + mlir::dyn_cast(op.getOutput().getType()); if (!output_type) { return rewriter.notifyMatchFailure(op, "requires known rank"); } @@ -4364,12 +4359,11 @@ class ConvertTensorScatterOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto tensor_ty = - op.getTensor().getType().template dyn_cast(); + auto tensor_ty = mlir::dyn_cast(op.getTensor().getType()); auto indices_ty = - op.getIndices().getType().template dyn_cast(); + mlir::dyn_cast(op.getIndices().getType()); auto updates_ty = - op.getUpdates().getType().template dyn_cast(); + mlir::dyn_cast(op.getUpdates().getType()); if (!tensor_ty || !indices_ty || !updates_ty) return failure(); // Last dimension of the indices needs to known at compile time for @@ -4421,13 +4415,13 @@ class ConvertTensorScatterOp : public OpRewritePattern { updates = rewriter.create( op->getLoc(), broadcast_to_type, op.getUpdates(), const_op); - updates_ty = updates.getType().template dyn_cast(); + updates_ty = mlir::dyn_cast(updates.getType()); } int64_t tensor_rank = tensor_ty.getRank(); int64_t indices_rank = indices_ty.getRank(); int64_t updates_rank = - updates.getType().template dyn_cast().getRank(); + mlir::dyn_cast(updates.getType()).getRank(); int64_t window_dims = tensor_rank - num_index_dims; auto dims_attr = ScatterDimensionNumbersAttr::get( @@ -4558,7 +4552,7 @@ class ConvertTileOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::TileOp op, PatternRewriter &rewriter) const override { - auto input_ty = op.getInput().getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty || !input_ty.hasStaticShape()) return failure(); ArrayRef input_shape = input_ty.getShape(); Type element_type = input_ty.getElementType(); @@ -4639,7 +4633,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getInput(); Value multiples = op.getMultiples(); - auto input_ty = input.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); if (!input_ty) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization // implemented. 
@@ -4659,7 +4653,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { } } - auto multiples_ty = multiples.getType().dyn_cast(); + auto multiples_ty = mlir::dyn_cast(multiples.getType()); int64_t multiples_rank = multiples_ty.getRank(); // rank of multiples input of tf.TileOp must be 1 if (multiples_rank != 1) return failure(); @@ -4728,16 +4722,14 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Type element_type = op.getOrigInput() - .getType() - .template cast() - .getElementType(); + Type element_type = + mlir::cast(op.getOrigInput().getType()).getElementType(); // Compute paddings using the original input and kernel shape and strides. // Here, ReduceWindow op as used as the MaxPool op is lowered to the // ReduceWindow op. auto input_ty = - op.getOrigInput().getType().template dyn_cast(); + mlir::dyn_cast(op.getOrigInput().getType()); if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), @@ -4798,9 +4790,8 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { return failure(); auto out_backprop_ty = - op.getOutBackprop().getType().template dyn_cast(); - auto filter_ty = - op.getFilter().getType().template dyn_cast(); + mlir::dyn_cast(op.getOutBackprop().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); // With the exception of out_backprop's batch dimension, out_backprop and // filter need to have static shape. Filter is validated here, out_backprop @@ -4824,7 +4815,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { } else { auto pack = op.getInputSizes().template getDefiningOp(); if (!pack || pack.getAxis() != 0) return failure(); - auto pack_ty = pack.getType().template dyn_cast(); + auto pack_ty = mlir::dyn_cast(pack.getType()); if (!pack_ty || pack_ty.getRank() != 1) return failure(); for (auto i = 0; i < pack_ty.getDimSize(0); ++i) { if (i == batch_dim) { @@ -4862,7 +4853,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { explicit_paddings.reserve(explicit_paddings_attr.size()); for (Attribute explicit_padding : explicit_paddings_attr) explicit_paddings.push_back( - explicit_padding.cast().getInt()); + mlir::cast(explicit_padding).getInt()); } ArrayRef filter_shape = filter_ty.getShape(); @@ -5029,9 +5020,8 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { return failure(); auto out_backprop_ty = - op.getOutBackprop().getType().template dyn_cast(); - auto input_ty = - op.getInput().getType().template dyn_cast(); + mlir::dyn_cast(op.getOutBackprop().getType()); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); for (RankedTensorType ty : {out_backprop_ty, input_ty}) if (!ty || !ty.hasStaticShape()) return failure(); @@ -5063,7 +5053,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { explicit_paddings.reserve(explicit_paddings_attr.size()); for (Attribute explicit_padding : explicit_paddings_attr) explicit_paddings.push_back( - explicit_padding.cast().getInt()); + mlir::cast(explicit_padding).getInt()); } constexpr int num_dims = num_spatial_dims + 2; @@ -5223,7 +5213,8 @@ class ConvertOneHotOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::OneHotOp op, PatternRewriter &rewriter) const override { - auto indices_ty = op.getIndices().getType().dyn_cast(); + auto indices_ty = + mlir::dyn_cast(op.getIndices().getType()); if (!indices_ty || !indices_ty.hasStaticShape()) 
return failure(); ArrayRef indices_shape = indices_ty.getShape(); Type element_type = indices_ty.getElementType(); @@ -5307,7 +5298,7 @@ class ConvertInfeedDequeueTupleOp result_types.reserve(op.getOutputs().size() + 1); for (const auto &output : op.getOutputs()) { Type ty = output.getType(); - if (auto tensor_ty = ty.dyn_cast()) { + if (auto tensor_ty = mlir::dyn_cast(ty)) { if (!tensor_ty.hasStaticShape()) return failure(); } result_types.push_back(ty); @@ -5412,7 +5403,7 @@ class ConvertTopKV2Op : public OpRewritePattern { if (!matchPattern(op.getK(), m_Constant(&k_attr))) return failure(); int64_t k = (*k_attr.begin()).getSExtValue(); - TensorType input_type = op.getInput().getType().cast(); + TensorType input_type = mlir::cast(op.getInput().getType()); if (!input_type.hasRank()) return failure(); int64_t input_rank = input_type.getRank(); int64_t last_dim_index = input_rank - 1; @@ -5436,7 +5427,7 @@ class ConvertUnpackOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::UnpackOp op, PatternRewriter &rewriter) const override { - auto value_type = op.getValue().getType().dyn_cast(); + auto value_type = mlir::dyn_cast(op.getValue().getType()); if (!value_type) return failure(); int64_t value_rank = value_type.getRank(); @@ -5482,7 +5473,7 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::UnpackOp op, PatternRewriter &rewriter) const override { - auto value_type = op.getValue().getType().dyn_cast(); + auto value_type = mlir::dyn_cast(op.getValue().getType()); if (!value_type) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization // implemented. @@ -5585,8 +5576,8 @@ class ConvertSigmoidGradOpDynamic : public OpRewritePattern { Location loc = op.getLoc(); Value y = op.getY(); Value dy = op.getDy(); - auto tp_y = y.getType().dyn_cast(); - auto tp_dy = dy.getType().dyn_cast(); + auto tp_y = mlir::dyn_cast(y.getType()); + auto tp_dy = mlir::dyn_cast(dy.getType()); if (!tp_y || !tp_dy) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization @@ -5598,7 +5589,7 @@ class ConvertSigmoidGradOpDynamic : public OpRewritePattern { if (elem_tp.isSignlessInteger()) { attr = rewriter.getIntegerAttr(elem_tp, 1); } else { - assert(elem_tp.isa()); + assert(mlir::isa(elem_tp)); attr = rewriter.getFloatAttr(elem_tp, 1); } Value one = rewriter.create( @@ -5640,13 +5631,12 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto data_type = - op.getData().getType().template dyn_cast(); + auto data_type = mlir::dyn_cast(op.getData().getType()); if (!data_type) return failure(); int64_t data_rank = data_type.getRank(); auto segment_ids_type = - op.getSegmentIds().getType().template dyn_cast(); + mlir::dyn_cast(op.getSegmentIds().getType()); if (!segment_ids_type) return failure(); int64_t segment_ids_rank = segment_ids_type.getRank(); @@ -5766,7 +5756,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { return success(); }; - auto input_type = op.getValue().getType().dyn_cast(); + auto input_type = mlir::dyn_cast(op.getValue().getType()); if (!input_type) return failure(); if (input_type.hasStaticShape() && input_type.getNumElements() <= 1) // No shuffling is required, so copy input directly to output. 
@@ -5966,16 +5956,16 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::InplaceUpdateOp op, PatternRewriter &rewriter) const override { - auto input = op.getX().dyn_cast>(); + auto input = mlir::dyn_cast>(op.getX()); if (!input) return failure(); auto indices = op.getI(); auto updates = op.getV(); // Slice each row of `i` and `v` to perform a separate dynamic-update-slice // on the contents of `x`. - auto input_type = input.getType().cast(); - auto updates_type = updates.getType().cast(); - auto indices_type = indices.getType().cast(); + auto input_type = mlir::cast(input.getType()); + auto updates_type = mlir::cast(updates.getType()); + auto indices_type = mlir::cast(indices.getType()); if (!input_type.hasRank()) return failure(); if (!updates_type.hasRank() || updates_type.isDynamicDim(0)) return failure(); @@ -6033,7 +6023,8 @@ class ConvertXlaDynamicUpdateSliceOp LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op, PatternRewriter &rewriter) const override { - auto indices_type = op.getIndices().getType().dyn_cast(); + auto indices_type = + mlir::dyn_cast(op.getIndices().getType()); if (!indices_type || !indices_type.hasStaticShape() || indices_type.getShape().size() != 1) return failure(); @@ -6062,8 +6053,8 @@ class ConvertXlaReduceScatterOp if (!matchPattern(op.getGroupAssignment(), m_Constant(&group_assignment))) return failure(); auto replica_groups = - hlo::convertElementsAttr(group_assignment, rewriter.getIntegerType(64)) - .cast(); + mlir::cast(hlo::convertElementsAttr( + group_assignment, rewriter.getIntegerType(64))); if (replica_groups.getType().getRank() != 2) return failure(); APInt scatter_dimension; @@ -6141,16 +6132,16 @@ class ConvertXlaReduceWindowOp // Create the mhlo.SelectAndScatter op. auto reduce_window_op = rewriter.create( loc, result_types, op.getInput(), op.getInitValue(), - hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(base_dilations, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(window_dilations, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)) - .cast()); + mlir::cast(hlo::convertElementsAttr( + window_dimensions, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_strides, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + base_dilations, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_dilations, rewriter.getIntegerType(64))), + mlir::cast( + hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); // Insert a call to the reducer in the region of the mhlo op. 
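The XlaReduceWindow hunk above shows another recurring pattern in this patch: TF ops carry window dimensions, strides, dilations, and padding as integer elements attributes, while the mhlo builders require i64 DenseIntElementsAttr, so each attribute is widened and then re-cast. A sketch of that step, assuming hlo::convertElementsAttr behaves as the surrounding call sites use it (the helper name is illustrative, and the mhlo utils header is whatever this file already includes):

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"

// Widen an integer elements attribute to i64 and recover the concrete
// DenseIntElementsAttr type expected by the mhlo op builders.
// hlo::convertElementsAttr is assumed available from the mhlo utils header
// used by the surrounding file.
mlir::DenseIntElementsAttr WidenToI64(mlir::ElementsAttr attr,
                                      mlir::Builder &builder) {
  return mlir::cast<mlir::DenseIntElementsAttr>(
      mlir::hlo::convertElementsAttr(attr, builder.getIntegerType(64)));
}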
mlir::SymbolRefAttr func = op.getComputation(); auto func_op = cast(SymbolTable::lookupSymbolIn( @@ -6177,9 +6168,9 @@ class ConvertClipByValueOp : public OpRewritePattern { Value min = op.getClipValueMin(); Value max = op.getClipValueMax(); - auto input_ty = input.getType().cast(); - auto min_ty = min.getType().cast(); - auto max_ty = max.getType().cast(); + auto input_ty = mlir::cast(input.getType()); + auto min_ty = mlir::cast(min.getType()); + auto max_ty = mlir::cast(max.getType()); if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) { return failure(); @@ -6215,8 +6206,9 @@ class ConvertConstOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::ConstOp op, PatternRewriter &rewriter) const override { // Convert only for valid HLO tensors. - auto ty = op.getType().dyn_cast(); - if (!ty || !ty.getElementType().isa()) + auto ty = mlir::dyn_cast(op.getType()); + if (!ty || + !mlir::isa(ty.getElementType())) return failure(); Location loc = op.getLoc(); @@ -6239,9 +6231,9 @@ class ConvertCumOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const override { - auto input = op.getX().template dyn_cast>(); + auto input = mlir::dyn_cast>(op.getX()); if (!input) return failure(); - auto input_type = input.getType().template dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); if (!input_type || !input_type.hasStaticShape()) { return failure(); } @@ -6352,7 +6344,7 @@ class ConvertShapeOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Value input = op.getInput(); - auto result_ty = op.getResult().getType().dyn_cast(); + auto result_ty = mlir::dyn_cast(op.getResult().getType()); if (!result_ty) { return failure(); } @@ -6373,8 +6365,8 @@ class ConvertDynamicExpandDimsOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::ExpandDimsOp op, PatternRewriter &rewriter) const override { auto input = op.getInput(); - auto input_ty = input.getType().cast(); - auto result_ty = op.getType().cast(); + auto input_ty = mlir::cast(input.getType()); + auto result_ty = mlir::cast(op.getType()); if (!result_ty.hasRank() || !input_ty.hasRank() || result_ty.hasStaticShape()) { return failure(); @@ -6431,8 +6423,8 @@ class ConvertDynamicSqueezeOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SqueezeOp op, PatternRewriter &rewriter) const override { auto input = op.getInput(); - auto input_ty = input.getType().cast(); - auto result_ty = op.getType().cast(); + auto input_ty = mlir::cast(input.getType()); + auto result_ty = mlir::cast(op.getType()); if (!result_ty.hasRank() || !input_ty.hasRank() || result_ty.hasStaticShape()) { return failure(); @@ -6465,9 +6457,7 @@ class ConvertDynamicSqueezeOp : public OpRewritePattern { auto from_extents = rewriter.create(op.getLoc(), dims); - // chlo::DynamicReshapeOp checks if the reshape is legal and will fail if - // any non-1 dimension is squeezed. 
- rewriter.replaceOpWithNewOp(op, result_ty, input, + rewriter.replaceOpWithNewOp(op, result_ty, input, from_extents); return success(); } @@ -6492,24 +6482,23 @@ class ConvertXlaConvV2Op : public OpRewritePattern { return failure(); auto window_strides_named_attr = rewriter.getNamedAttr( - "window_strides", hlo::convertElementsAttr(window_strides_attr, - rewriter.getIntegerType(64)) - .cast()); + "window_strides", + mlir::cast(hlo::convertElementsAttr( + window_strides_attr, rewriter.getIntegerType(64)))); auto padding_named_attr = rewriter.getNamedAttr( - "padding", - hlo::convertElementsAttr(padding_attr, rewriter.getIntegerType(64)) - .cast()); + "padding", mlir::cast(hlo::convertElementsAttr( + padding_attr, rewriter.getIntegerType(64)))); auto lhs_dilation_named_attr = rewriter.getNamedAttr( "lhs_dilation", - hlo::convertElementsAttr(lhs_dilation_attr, rewriter.getIntegerType(64)) - .cast()); + mlir::cast(hlo::convertElementsAttr( + lhs_dilation_attr, rewriter.getIntegerType(64)))); auto rhs_dilation_named_attr = rewriter.getNamedAttr( "rhs_dilation", - hlo::convertElementsAttr(rhs_dilation_attr, rewriter.getIntegerType(64)) - .cast()); + mlir::cast(hlo::convertElementsAttr( + rhs_dilation_attr, rewriter.getIntegerType(64)))); int64_t feature_group_count_val = feature_group_count_attr.getValues()[0].getInt(); @@ -6566,12 +6555,12 @@ class ConvertXlaSelectAndScatterOp // Create the mhlo.SelectAndScatter op. auto select_and_scatter_op = rewriter.create( loc, result_types, op.getOperand(), op.getSource(), op.getInitValue(), - hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)) - .cast()); + mlir::cast(hlo::convertElementsAttr( + window_dimensions, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_strides, rewriter.getIntegerType(64))), + mlir::cast( + hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); auto insert_call_to = [&](const mlir::SymbolRefAttr &func, Region *region) { auto func_op = cast(SymbolTable::lookupSymbolIn( @@ -6671,7 +6660,7 @@ class ConvertXlaVariadicReduceV2Op auto func_ty = func_op.getFunctionType(); SmallVector elementTypes{llvm::map_range( func_ty.getResults(), - [](Type ty) { return ty.cast().getElementType(); })}; + [](Type ty) { return mlir::cast(ty).getElementType(); })}; // Create the mhlo.reduce op. auto reduce_op = rewriter.create( @@ -6754,7 +6743,7 @@ class LowerYieldOp : public OpConversionPattern { // Returns a new tensor type from the given type with element type updated to // the given type. 
TensorType UpdateElementTypeTo(Type ty, Type element_ty) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); if (!ranked_ty) { return UnrankedTensorType::get(element_ty); } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc index 54bd5812644488..34df8fc9759a5c 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc @@ -113,9 +113,8 @@ LogicalResult ConvertReplicaGroups(OpBuilder& builder, if (!matchPattern(group_assignment_value, m_Constant(&group_assignment))) { return op->emitOpError() << "expects constant group_assignment"; } - replica_groups = - hlo::convertElementsAttr(group_assignment, builder.getIntegerType(64)) - .cast(); + replica_groups = mlir::cast( + hlo::convertElementsAttr(group_assignment, builder.getIntegerType(64))); if (replica_groups.getType().getRank() != 2) { return op->emitOpError() << "group_assignment should have rank 2, got " << replica_groups.getType().getRank(); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc index 3e8dd5b58ed2f1..68c412f79ff393 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc @@ -458,7 +458,7 @@ SmallVector GetValueWithToken( return new_result; }; - auto tuple_type = value.getType().dyn_cast(); + auto tuple_type = mlir::dyn_cast(value.getType()); // `value` is not a tuple, create a new tuple. if (!tuple_type) return {create_tuple({value, token})}; @@ -499,7 +499,7 @@ SmallVector GetTypeWithToken(OpBuilder& builder, ArrayRef types, } auto type = types[0]; - if (auto tuple_type = type.dyn_cast()) { + if (auto tuple_type = mlir::dyn_cast(type)) { auto result_types = llvm::to_vector(tuple_type.getTypes()); result_types.push_back(token_type); return {builder.getTupleType(result_types)}; @@ -536,7 +536,7 @@ void ReplaceWithTupleResult(OpBuilder& builder, ValueRange values, auto value = values[0]; auto replacement = replacements[0]; - auto tuple_type = value.getType().dyn_cast(); + auto tuple_type = mlir::dyn_cast(value.getType()); if (!tuple_type) { if (!value.use_empty()) { auto new_element = builder.create(replacement.getLoc(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 588639e3435aae..401d1e8b954e40 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -627,10 +627,6 @@ foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { (addBenefit 2)>; } -// Lowering tf.Reshape with dynamic shape -def : Pat<(TF_ReshapeOp:$res MHLO_Tensor:$arg, $shape), - (CHLO_DynamicReshapeOp $arg, $shape)>; - // Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. 
def : Pat<(TF_SignOp $x), (MHLO_SignOp $x)>; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc index d5560f2481b00f..ce8b46708d2f52 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -75,13 +76,13 @@ namespace { // Returns true if the given type is a ranked tensor type with static or bounded // dimensions. bool IsBounded(Type ty) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); if (!ranked_ty) return false; if (ranked_ty.hasStaticShape()) return true; auto encoding = - ranked_ty.getEncoding().dyn_cast_or_null(); + mlir::dyn_cast_or_null(ranked_ty.getEncoding()); if (!encoding) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { @@ -96,10 +97,11 @@ bool IsBounded(Type ty) { bool HasSymbolRefAttr(Operation* op) { for (const auto& attr : op->getAttrs()) { Attribute attr_value = attr.getValue(); - if (attr_value.isa()) { + if (mlir::isa(attr_value)) { return true; - } else if (auto array_attr = attr_value.dyn_cast()) { - if (!array_attr.empty() && array_attr.begin()->isa()) { + } else if (auto array_attr = mlir::dyn_cast(attr_value)) { + if (!array_attr.empty() && + mlir::isa(*array_attr.begin())) { return true; } } @@ -146,8 +148,8 @@ class Tf2XlaRewritePattern : public ConversionPattern { }; bool ShouldRefineTypeTo(Type original_ty, Type updated_ty) { - auto updated = updated_ty.dyn_cast(); - auto original = original_ty.dyn_cast(); + auto updated = mlir::dyn_cast(updated_ty); + auto original = mlir::dyn_cast(original_ty); // Both types must be shaped types. if (!original || !updated) return false; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc index 142bd2b379208f..e43bcdf6d3a26e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc @@ -35,7 +35,7 @@ using ::mlir::ModuleOp; using ::mlir::OwningOpRef; using ::tsl::StatusOr; -StatusOr> GetMlirModuleFromString( +absl::StatusOr> GetMlirModuleFromString( absl::string_view module_string, MLIRContext* context) { DialectRegistry mlir_registry; RegisterCommonToolingDialects(mlir_registry); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h index 9a6aeb44a27279..13baaba06aadb9 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h @@ -28,7 +28,7 @@ namespace test { // Given a raw string, return a ModuleOp that can be used with the given // MLIRContext. 
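The `HasSymbolRefAttr` hunk above checks both direct attributes and array attributes; its elided template arguments are evidently `SymbolRefAttr` and `ArrayAttr`. A self-contained sketch of that predicate:

```cpp
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"

// True if any attribute on `op` is a SymbolRefAttr, or an ArrayAttr whose
// first element is one.
static bool HasSymbolRefAttrSketch(mlir::Operation* op) {
  for (const mlir::NamedAttribute& attr : op->getAttrs()) {
    mlir::Attribute value = attr.getValue();
    if (mlir::isa<mlir::SymbolRefAttr>(value)) return true;
    if (auto array = mlir::dyn_cast<mlir::ArrayAttr>(value)) {
      if (!array.empty() && mlir::isa<mlir::SymbolRefAttr>(*array.begin()))
        return true;
    }
  }
  return false;
}
```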
-tsl::StatusOr> GetMlirModuleFromString( +absl::StatusOr> GetMlirModuleFromString( absl::string_view module_string, MLIRContext* mlir_context); } // namespace test diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index b17d474f85a652..2709f9dada21a7 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -132,7 +132,7 @@ Tf2XlaRewriter::~Tf2XlaRewriter() { if (context_) context_->Unref(); } -tsl::StatusOr Tf2XlaRewriter::ImportXlaComputation( +absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( XlaComputation& computation) { xla::DebugOptions debug_options; TF_ASSIGN_OR_RETURN(auto hlo_module_config, @@ -205,7 +205,7 @@ LogicalResult Tf2XlaRewriter::PrepareParams() { // concurrently running each of the MLIR functions create a new device. step_container_ = std::make_unique( /*step_id=*/0, cleanup); - tsl::Status status = step_container_->Create( + absl::Status status = step_container_->Create( device_->resource_manager(), tensorflow::XlaContext::kXlaContextResourceName, context_); if (!status.ok()) { @@ -214,7 +214,7 @@ LogicalResult Tf2XlaRewriter::PrepareParams() { } params_.step_container = step_container_.get(); - tsl::StatusOr version_or = tensorflow::GetTfGraphProducerVersion( + absl::StatusOr version_or = tensorflow::GetTfGraphProducerVersion( op_->getParentOfType()); if (!version_or.ok()) { return emitError(op_->getLoc()) << version_or.status().ToString(); @@ -232,13 +232,13 @@ LogicalResult Tf2XlaRewriter::PrepareParams() { // Returns true if the given type is a ranked tensor type with static or // bounded dimensions. bool IsBounded(Type ty) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); if (!ranked_ty) return false; if (ranked_ty.hasStaticShape()) return true; auto encoding = - ranked_ty.getEncoding().dyn_cast_or_null(); + mlir::dyn_cast_or_null(ranked_ty.getEncoding()); if (!encoding) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { @@ -253,10 +253,11 @@ bool IsBounded(Type ty) { bool HasSymbolRefAttr(Operation* op) { for (const auto& attr : op->getAttrs()) { Attribute attr_value = attr.getValue(); - if (attr_value.isa()) { + if (mlir::isa(attr_value)) { return true; - } else if (auto array_attr = attr_value.dyn_cast()) { - if (!array_attr.empty() && array_attr.begin()->isa()) { + } else if (auto array_attr = mlir::dyn_cast(attr_value)) { + if (!array_attr.empty() && + mlir::isa(*array_attr.begin())) { return true; } } @@ -305,7 +306,7 @@ LogicalResult Tf2XlaRewriter::PrepareKernelInputs( LogicalResult Tf2XlaRewriter::LegalizeOp() { for (Type ty : op_->getOperandTypes()) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); // Only bounded operands are supported in the XLA builders. 
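The `IsBounded` helpers duplicated in `legalize_tf_with_tf2xla.cc` and `tf2xla_rewriter.cc` accept a ranked tensor whose dynamic dimensions all carry bounds in the type encoding. A sketch, assuming `mhlo::TypeExtensionsAttr` for the encoding attribute and `ShapedType::kDynamic` for the missing-bound sentinel (both elided above):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

// A type is "bounded" when it is ranked and every dynamic dimension has a
// bound recorded in the encoding attribute.
static bool IsBoundedSketch(mlir::Type ty) {
  auto ranked = mlir::dyn_cast<mlir::RankedTensorType>(ty);
  if (!ranked) return false;
  if (ranked.hasStaticShape()) return true;

  auto encoding = mlir::dyn_cast_or_null<mlir::mhlo::TypeExtensionsAttr>(
      ranked.getEncoding());
  if (!encoding) return false;

  for (int i = 0; i < ranked.getRank(); ++i) {
    if (ranked.isDynamicDim(i) &&
        encoding.getBounds()[i] == mlir::ShapedType::kDynamic) {
      return false;
    }
  }
  return true;
}
```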
if (!IsBounded(ranked_ty)) { return op_->emitRemark() @@ -328,7 +329,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { if (failed(PrepareParams())) return failure(); std::shared_ptr props; - tsl::Status status = tensorflow::NodeProperties::CreateFromNodeDef( + absl::Status status = tensorflow::NodeProperties::CreateFromNodeDef( *nodedef_or.value(), params_.function_library->GetFunctionLibraryDefinition(), &props); if (!status.ok()) { @@ -387,11 +388,11 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { if (failed(VerifyOpResults(op_context))) return failure(); - StatusOr tuple_result_or_status = - CompileWithHloImporter(op_context); - if (!tuple_result_or_status.ok()) { - return op_->emitRemark() << tuple_result_or_status.status().ToString(); - } + absl::StatusOr tuple_result_or_status = + CompileWithHloImporter(op_context); + if (!tuple_result_or_status.ok()) { + return op_->emitRemark() << tuple_result_or_status.status().ToString(); + } mhlo::TupleOp tuple_result = tuple_result_or_status.value(); llvm::SmallVector output_values; @@ -403,7 +404,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { return success(); } -tsl::StatusOr Tf2XlaRewriter::CompileWithHloImporter( +absl::StatusOr Tf2XlaRewriter::CompileWithHloImporter( tensorflow::OpKernelContext& op_context) { // XLA can only return a single value. Wrap all output op return values // in a Tuple op that gets unpacked later. diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h index 71cafc5579ff16..2b8c52750a6c44 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h @@ -56,12 +56,12 @@ class Tf2XlaRewriter { // Compiles the given Operation with XlaBuilder and imports the generated HLO // via the HLO -> MHLO importer. - tsl::StatusOr CompileWithHloImporter( + absl::StatusOr CompileWithHloImporter( tensorflow::OpKernelContext& op_context); // Import the given XlaComputation into the parent module. Returns the given // generated function. - tsl::StatusOr ImportXlaComputation( + absl::StatusOr ImportXlaComputation( xla::XlaComputation& computation); // Prepares OpKernelContext params common to all the ops. 
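`tsl::Status` and `tsl::StatusOr` are aliases of the `absl` types, so the renames in these hunks are behavior-preserving spelling changes. A minimal usage sketch of the target API:

```cpp
#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Returns the value when positive, an error status otherwise.
absl::StatusOr<int> CheckPositive(int v) {
  if (v <= 0) return absl::InvalidArgumentError("expected a positive value");
  return v;
}
```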
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc index 061889965aebd9..aecf9db3f0d5fe 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -127,7 +127,7 @@ class Tf2XlaRewriterTest : public ::testing::Test { module_, test::GetMlirModuleFromString(module_string, &context_)); context_.loadAllAvailableDialects(); - return tsl::OkStatus(); + return absl::OkStatus(); } Status LegalizeSingleOp(Operation& op) { @@ -143,7 +143,7 @@ class Tf2XlaRewriterTest : public ::testing::Test { return tsl::errors::Internal("Failed to rewrite op"); } - return tsl::OkStatus(); + return absl::OkStatus(); } Status LegalizeModule(std::string module_string = kMlirModuleStr) { @@ -170,7 +170,7 @@ class Tf2XlaRewriterTest : public ::testing::Test { return tsl::errors::Internal("Could not legalize all ops"); } - return tsl::OkStatus(); + return absl::OkStatus(); } mlir::func::FuncOp GetMainFunc() { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc index 7938fc4684ce2b..a6435081820880 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -89,18 +90,18 @@ static void IncrementCounterFor(tensorflow::monitoring::Counter<1>* counter, } bool HasBounds(RankedTensorType type) { - auto encoding = - type.getEncoding().dyn_cast_or_null(); + auto encoding = mlir::dyn_cast_or_null( + type.getEncoding()); return (encoding && !encoding.getBounds().empty()); } bool HasStaticShapeOrBounded(Value val) { auto type = val.getType(); - if (type.isa()) { + if (mlir::isa(type)) { return false; } - if (type.isa()) { - auto ranked_tensor = type.dyn_cast(); + if (mlir::isa(type)) { + auto ranked_tensor = mlir::dyn_cast(type); if (ranked_tensor.hasStaticShape()) { return true; } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc index 39eadcb93fcfce..4183d181fc5611 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc @@ -40,7 +40,7 @@ using ::mlir::OwningOpRef; using ::mlir::PassManager; using ::tensorflow::monitoring::testing::CellReader; -StatusOr> GetMlirModuleFromString( +absl::StatusOr> GetMlirModuleFromString( absl::string_view module_string, MLIRContext* context) { mlir::DialectRegistry mlir_registry; RegisterAllTensorFlowDialects(mlir_registry); diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 1ce45fe7345c11..4583fc9cd967e2 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -38,9 +38,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tosa/tfl_passes.h" #include "tensorflow/compiler/mlir/tosa/transforms/passes.h" #include "xla/mlir/framework/transforms/passes.h" -#include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" -#include "xla/service/cpu/hlo_xla_runtime_pipeline.h" int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); @@ -52,7 +50,6 @@ int main(int argc, char **argv) { mlir::tf_saved_model::registerTensorFlowSavedModelPasses(); mlir::TFL::registerTensorFlowLitePasses(); mlir::mhlo::registerAllMhloPasses(); - mlir::lmhlo::registerAllLmhloPasses(); // These are in compiler/mlir/tf2xla and not part of the above MHLO passes. mlir::mhlo::registerLegalizeTfPasses(); diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc index 1e5114c1103c1f..906f828f2d5023 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc @@ -57,7 +57,7 @@ namespace tfr { const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR"; -StatusOr> TFRDecomposeContext::Get( +absl::StatusOr> TFRDecomposeContext::Get( mlir::MLIRContext* mlir_ctx) { Env* env = Env::Default(); std::string tfr_lib_dir; @@ -121,8 +121,8 @@ std::unique_ptr TFRDecomposeContext::GetFromText( return std::make_unique(module_op); } -StatusOr TFRDecomposeContext::ExpandNode(const NodeDef& node_def, - StringPiece func_name) { +absl::StatusOr TFRDecomposeContext::ExpandNode( + const NodeDef& node_def, StringPiece func_name) { const OpDef* op_def; TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); DataTypeVector input_dtys, output_dtys; @@ -209,8 +209,8 @@ TFRDecomposeContext::TFRDecomposeContext(mlir::ModuleOp tfr_module) void TFRDecomposeContext::Destroy() { tfr_module_.erase(); } -StatusOr ExpandNode(const NodeDef& node_def, - StringPiece func_name) { +absl::StatusOr ExpandNode(const NodeDef& node_def, + StringPiece func_name) { mlir::MLIRContext mlir_ctx; TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(&mlir_ctx)); return ctx->ExpandNode(node_def, func_name); diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td index 0e036caf4d2c77..b3bd4d618bd808 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -41,7 +41,6 @@ def TFR_Dialect : Dialect { }]; let cppNamespace = "::mlir::TFR"; - let usePropertiesForAttributes = 0; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc index 5d59d958d3e7c9..988dc9e612b9c3 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -84,8 +84,8 @@ namespace { // Quantize the float value based on given scale and zero point attributes. IntegerAttr Quantize(float value, Attribute scale_attr, Attribute zp_attr, OpBuilder builder) { - double scale = scale_attr.cast().getValueAsDouble(); - int64_t zp = zp_attr.cast().getInt(); + double scale = mlir::cast(scale_attr).getValueAsDouble(); + int64_t zp = mlir::cast(zp_attr).getInt(); int quantized = static_cast(std::round(value / scale) + zp); quantized = @@ -187,11 +187,12 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { // default value in the argument attribute. 
llvm::SmallVector new_operands; for (auto arg : llvm::enumerate(compose_func_type.getInputs())) { - if (auto tensor_type = arg.value().dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(arg.value())) { auto casted = builder.create(op->getLoc(), tensor_type, op->getOperand(arg.index())); new_operands.push_back(casted); - } else if (auto list_type = arg.value().dyn_cast()) { + } else if (auto list_type = + mlir::dyn_cast(arg.value())) { llvm::SmallVector variadic_operands; for (int i = arg.index(); i < op->getNumOperands(); i++) { auto casted = builder.create( @@ -211,8 +212,8 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { } if (!attribute && attr_name.getValue() == "out_type") { auto type = op->getResult(0).getType(); - if (type.isa()) { - type = type.cast().getElementType(); + if (mlir::isa(type)) { + type = mlir::cast(type).getElementType(); } attribute = TypeAttr::get(type); } @@ -220,8 +221,9 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { // Wrap these special attributes as a special TFR constant, so the SSA // value has a valid type to be used as TFR function argument. These // attributes are not expected to be manipulated by the lowering passes. - if (attribute.isa() || attribute.isa() || - attribute.isa() || attribute.isa()) { + if (mlir::isa(attribute) || mlir::isa(attribute) || + mlir::isa(attribute) || + mlir::isa(attribute)) { TFRAttrType output_type = TFRAttrType::get(builder.getContext()); attr_cst = builder.create(op->getLoc(), output_type, attribute); @@ -245,9 +247,10 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { // op result. llvm::SmallVector new_results; for (auto res : llvm::enumerate(compose_func_type.getResults())) { - if (res.value().dyn_cast()) { + if (mlir::dyn_cast(res.value())) { new_results.push_back(new_op.getResult(res.index())); - } else if (auto list_type = res.value().dyn_cast()) { + } else if (auto list_type = + mlir::dyn_cast(res.value())) { for (int i = res.index(), j = 0; i < op->getNumResults(); i++, j++) { auto index = builder.create( op->getLoc(), builder.getIndexAttr(j)); diff --git a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc index dd85565cfed88e..61aa404847ee07 100644 --- a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc +++ b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc @@ -136,7 +136,7 @@ class RewriteTFRCallOp : public OpRewritePattern { // by the frontend correctly. 
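The `out_type` fallback above derives the attribute from the op's first result; the elided cast argument is evidently `ShapedType`, given the `getElementType()` call. A sketch:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"

// Derive a TypeAttr for "out_type" from an op's first result, unwrapping
// shaped types to their element type.
static mlir::TypeAttr DeriveOutTypeAttr(mlir::Operation* op) {
  mlir::Type type = op->getResult(0).getType();
  if (auto shaped = mlir::dyn_cast<mlir::ShapedType>(type))
    type = shaped.getElementType();
  return mlir::TypeAttr::get(type);
}
```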
Value CastToNonDerivedType(PatternRewriter& rewriter, Location loc, CastOp cast_op, Type input_tfr_type) const { - auto tensor_type = input_tfr_type.dyn_cast(); + auto tensor_type = mlir::dyn_cast(input_tfr_type); if (!tensor_type) return cast_op.getArg(); auto attr_names = tensor_type.getAttrKeys(); @@ -150,7 +150,7 @@ class RewriteTFRCallOp : public OpRewritePattern { } Type original_input_type = - cast_op.getInputElementType().cast().getValue(); + mlir::cast(cast_op.getInputElementType()).getValue(); if (result_elt_type != original_input_type) { UnrankedTensorType result_type = UnrankedTensorType::get(result_elt_type); return rewriter.create(loc, result_type, cast_op.getArg()); @@ -166,10 +166,10 @@ class RewriteTFRCallOp : public OpRewritePattern { llvm::SmallVectorImpl& input_values) const { if (input_types.size() <= 1) return; - Type target_input_type = input_types[0].cast().getValue(); + Type target_input_type = mlir::cast(input_types[0]).getValue(); auto result_type = UnrankedTensorType::get(target_input_type); for (auto i = 1; i < input_types.size(); ++i) { - Type current_input_type = input_types[i].cast().getValue(); + Type current_input_type = mlir::cast(input_types[i]).getValue(); if (current_input_type != target_input_type) { input_values[i] = rewriter.create(loc, result_type, input_values[i]); @@ -189,7 +189,7 @@ LogicalResult RewriteTFRCallOp::AddDerivedAttrs( llvm::StringMap* derived_attrs) const { // If there is an attribute associated to the input in the signature, we // store it as an derived attribute. - if (auto tensor_type = input_tfr_type.dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(input_tfr_type)) { auto attr_names = tensor_type.getAttrKeys(); if (attr_names.empty()) return success(); @@ -201,7 +201,7 @@ LogicalResult RewriteTFRCallOp::AddDerivedAttrs( // If there is an attribute associated to the input in the signature, // we store it as an derived attribute. - if (auto list_type = input_tfr_type.dyn_cast()) { + if (auto list_type = mlir::dyn_cast(input_tfr_type)) { auto attr_names = list_type.getAttrKeys(); if (attr_names.empty()) return success(); @@ -314,7 +314,7 @@ Attribute RewriteTFRCallOp::ProcessAttributeValue(Attribute attr, if (!attr_type) return attr; if (attr_type.getValue() == "tensor") { - if (auto f = attr.dyn_cast()) { + if (auto f = mlir::dyn_cast(attr)) { RankedTensorType type = RankedTensorType::get({}, f.getType()); return DenseFPElementsAttr::get(type, attr); } @@ -332,13 +332,13 @@ LogicalResult RewriteTFRCallOp::DeriveOutputTypes( const llvm::StringMap& attrs, SmallVectorImpl* output_types) const { for (auto res : llvm::enumerate(signature.getResults())) { - if (auto tensor_type = res.value().dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(res.value())) { // tfr.tensor should only have one attribute attached. 
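`CastValuesToSameType` reads each input's element type from a `TypeAttr` and inserts a cast whenever it disagrees with the first input's. The created op's template argument is elided above, so an `UnrealizedConversionCastOp` stands in for the dialect's cast op in this sketch:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"

// Cast every input whose element type differs from the first input's to an
// unranked tensor of the common element type.
static void CastValuesToSameTypeSketch(
    mlir::PatternRewriter& rewriter, mlir::Location loc,
    llvm::ArrayRef<mlir::Attribute> input_types,
    llvm::SmallVectorImpl<mlir::Value>& input_values) {
  if (input_types.size() <= 1) return;
  mlir::Type target = mlir::cast<mlir::TypeAttr>(input_types[0]).getValue();
  auto result_type = mlir::UnrankedTensorType::get(target);
  for (size_t i = 1; i < input_types.size(); ++i) {
    mlir::Type current = mlir::cast<mlir::TypeAttr>(input_types[i]).getValue();
    if (current != target) {
      input_values[i] = rewriter
                            .create<mlir::UnrealizedConversionCastOp>(
                                loc, mlir::TypeRange{result_type},
                                mlir::ValueRange{input_values[i]})
                            .getResult(0);
    }
  }
}
```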
auto attr_key = tensor_type.getAttrKeys().front(); Builder builder(signature.getContext()); if (auto attr = attrs.lookup(attr_key.getValue())) { output_types->push_back( - UnrankedTensorType::get(attr.cast().getValue())); + UnrankedTensorType::get(mlir::cast(attr).getValue())); } else if (Type element_type = GetFixedElementType(attr_key.getValue(), builder)) { output_types->push_back(UnrankedTensorType::get(element_type)); @@ -350,16 +350,18 @@ LogicalResult RewriteTFRCallOp::DeriveOutputTypes( continue; } - if (auto list_type = res.value().dyn_cast()) { + if (auto list_type = mlir::dyn_cast(res.value())) { // There are two cases: N*T or list(dtype) auto attr_keys = list_type.getAttrKeys(); // N*T case if (attr_keys.size() == 2) { // The first one is N, and the second one is T int list_size = - attrs.lookup(attr_keys[0].getValue()).cast().getInt(); + mlir::cast(attrs.lookup(attr_keys[0].getValue())) + .getInt(); Type list_type = - attrs.lookup(attr_keys[1].getValue()).cast().getValue(); + mlir::cast(attrs.lookup(attr_keys[1].getValue())) + .getValue(); for (int i = 0; i < list_size; ++i) { output_types->push_back(UnrankedTensorType::get(list_type)); } @@ -398,11 +400,12 @@ LogicalResult RewriteTFRCallOp::CreateAndReplaceOp( SmallVector new_results; for (auto res : llvm::enumerate(call_op.getResultTypes())) { Type res_type = res.value(); - if (res_type.dyn_cast()) { + if (mlir::dyn_cast(res_type)) { Value new_res = new_op->getResult(res.index()); auto casted = rewriter.create(loc, res_type, new_res); new_results.push_back(casted.getOut()); - } else if (auto list_type = res.value().dyn_cast()) { + } else if (auto list_type = + mlir::dyn_cast(res.value())) { SmallVector tensor_list; for (int i = res.index(); i < new_op->getNumResults(); i++) { Value new_res = new_op->getResult(i); diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 31b6aa272faf1d..1b6dbbd9176d22 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -145,6 +145,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", @@ -166,6 +167,7 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", @@ -206,6 +208,7 @@ cc_library( "transforms/deduplicate_if_result_pass.cc", "transforms/fuse_tpu_compile_and_execute_ops.cc", "transforms/insert_tensor_copy.cc", + "transforms/lower_bound_batch_threads.cc", "transforms/lower_saved_model.cc", "transforms/merge_tf_if_ops.cc", "transforms/optimize.cc", @@ -289,22 +292,21 @@ cc_library( deps = [ ":tf_to_tfrt", ":tfrt_compile_options", + ":tfrt_pipeline_options", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", - "@llvm-project//mlir:FuncDialect", 
"@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@tf_runtime//:bef", "@tf_runtime//:core_runtime", - "@tf_runtime//:hostcontext", "@tf_runtime//:mlirtobef", - "@tf_runtime//:tensor", ], ) @@ -331,6 +333,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:errors", "@tf_runtime//:bef", "@tf_runtime//:core_runtime", @@ -347,7 +350,6 @@ cc_library( "translate/import_model.h", ], visibility = [ - # copybara:uncomment "//learning/brain/experimental/tfrt/visualization:__pkg__", "//tensorflow/compiler/mlir/tfrt/tests/saved_model:__pkg__", "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:__pkg__", "//tensorflow/core/tfrt/graph_executor:__pkg__", @@ -406,10 +408,7 @@ cc_library( hdrs = ["translate/tfrt_compile_options.h"], compatible_with = get_compatible_with_portable(), # copybara: comment visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core/protobuf:for_core_protos_cc", - "@com_google_absl//absl/strings", - ], + deps = ["//tensorflow/core/protobuf:for_core_protos_cc"], ) cc_library( @@ -426,6 +425,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@tf_runtime//:compiler_tfrt_op_interfaces", ], ) @@ -621,6 +621,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@tf_runtime//:core_runtime_opdefs", ], ) @@ -641,6 +642,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", ], @@ -662,6 +664,7 @@ cc_library( hdrs = ["transforms/tpu_passes.h"], visibility = [":friends"] + if_google([ "//learning/brain/tfrt/ifrt/pjrt/__subpackages__", + "//learning/serving/servables/tfrt:__subpackages__", ]), deps = [ ":fallback_converter", diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc index 5573e7c2d46866..28f582723c8b2f 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/constants.h" #include "tensorflow/core/tfrt/fallback/cost_recorder.h" @@ -59,14 +60,14 @@ int64_t InferLookupTableFindV2Cost(const CostContext& context, constexpr int64_t kLookupTableFindCostScale = 8; constexpr int64_t kLookupTableFindStringKeyCostScale = 16; - auto value_type = op.getValues().getType().cast(); - auto key_type = op.getKeys().getType().cast(); + auto value_type = mlir::cast(op.getValues().getType()); + auto key_type = mlir::cast(op.getKeys().getType()); int64_t output_size = InferTensorSize(context, value_type); int64_t cost = kLookupTableFindCostScale * output_size; - if (key_type.getElementType().isa()) + if (mlir::isa(key_type.getElementType())) cost *= kLookupTableFindStringKeyCostScale; return cost; @@ -74,15 +75,15 @@ int64_t InferLookupTableFindV2Cost(const CostContext& context, // The cost function for tf.GatherV2. 
int64_t InferGatherV2Cost(const CostContext& context, mlir::TF::GatherV2Op op) { - return InferTensorSize(context, - op.getOutput().getType().cast()); + return InferTensorSize( + context, mlir::cast(op.getOutput().getType())); } // The cost function for tf.SparseSegmentSumOp. template int64_t InferSparseSegmentOpCost(const CostContext& context, OpType op) { return InferTensorSize( - context, op.getOutput().getType().template cast()); + context, mlir::cast(op.getOutput().getType())); } // CostFunctionRegistry is a map from op names to their cost functions. @@ -145,8 +146,8 @@ void CostAnalysis::AnalyzeArguments(mlir::func::FuncOp func_op) { // Use the max size among function inputs as the default size of dynamic // shaped tensors in the function. for (auto arg : func_op.getArguments()) { - if (!arg.getType().isa()) continue; - auto type = arg.getType().cast(); + if (!mlir::isa(arg.getType())) continue; + auto type = mlir::cast(arg.getType()); if (type.hasRank()) { max_arg_size_ = std::max(max_arg_size_, GetRankedTensorSize(type)); } @@ -204,7 +205,7 @@ void CostAnalysis::EvaluateCost(mlir::Operation* op) { // For other ops, use the sum of input sizes as its cost. int64_t cost = kDefaultCheapCost; for (auto operand : op->getOperands()) { - auto type = operand.getType().cast(); + auto type = mlir::cast(operand.getType()); if (type.hasRank()) { cost += GetRankedTensorSize(type); } else { diff --git a/tensorflow/compiler/mlir/tfrt/function/function.cc b/tensorflow/compiler/mlir/tfrt/function/function.cc index 42b7ff2b38982a..c29b5aeabda8ea 100644 --- a/tensorflow/compiler/mlir/tfrt/function/function.cc +++ b/tensorflow/compiler/mlir/tfrt/function/function.cc @@ -15,23 +15,20 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/function/function.h" +#include "absl/log/log.h" #include "absl/strings/match.h" -#include "absl/strings/str_split.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime #include "tfrt/bef_converter/mlir_to_bef.h" // from @tf_runtime -#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime -#include "tfrt/core_runtime/op_handler.h" // from @tf_runtime -#include "tfrt/host_context/host_context.h" // from @tf_runtime -#include "tfrt/tensor/dense_host_tensor_view.h" // from @tf_runtime namespace tensorflow { @@ -93,7 +90,7 @@ Status CompileTFMLIRToBEF(const TfrtFunctionCompileOptions& options, return diag_handler.Combine( tensorflow::errors::Internal("failed to convert MLIR to BEF.")); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD 
b/tensorflow/compiler/mlir/tfrt/ir/BUILD index 68e9624e118453..b29066807fbf78 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -29,6 +29,7 @@ cc_library( ":tfrt_fallback_opdefs_inc_gen", "@llvm-project//mlir:IR", "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", ], ) @@ -43,7 +44,6 @@ cc_library( # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", # copybara:uncomment "//learning/brain/tfrt/tpu/compiler/mlir:__subpackages__", "//tensorflow/compiler/mlir/tfrt:__subpackages__", - "//tensorflow/core/runtime_fallback:__subpackages__", "//tensorflow/core/tfrt/saved_model:friends", ], deps = [ @@ -98,6 +98,7 @@ cc_library( deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@tf_runtime//:basic_kernels_opdefs", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.td b/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.td index fce1756d11df31..e8ba2fc4a47ac1 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.td +++ b/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.td @@ -30,7 +30,6 @@ def TFRT_GPU_Dialect : Dialect { }]; let cppNamespace = "::tfrt::gpu"; - let usePropertiesForAttributes = 0; } class Gpu_Op traits = []> : diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index bfc93b9252ccbf..374aad2a242d9b 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -14,7 +14,6 @@ td_library( includes = ["."], visibility = [ # copybara:uncomment "//learning/brain/tfrt/mlir:__subpackages__", - "//learning/infra/mira/distributed:__subpackages__", ], deps = [ "@llvm-project//mlir:OpBaseTdFiles", @@ -51,7 +50,6 @@ cc_library( ], visibility = [ # copybara:uncomment "//learning/brain/tfrt/mlir:__subpackages__", - "//learning/infra/mira/distributed:__subpackages__", "//tensorflow/compiler/mlir/tfrt:__subpackages__", ], deps = [ @@ -59,6 +57,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:Support", ], ) @@ -70,9 +69,6 @@ td_library( "tf_ops.td", ], includes = ["."], - visibility = [ - # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", - ], deps = [ ":mlrt_td_files", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", @@ -155,7 +151,6 @@ cc_library( hdrs = ["tf_mlrt_ops.h"], visibility = [ # copybara:uncomment "//learning/brain/experimental/tfrt/mlrt/application/tensorflow/tests:__subpackages__", - # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", "//tensorflow/compiler/mlir/tfrt:__subpackages__", ], deps = [ @@ -167,6 +162,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@tf_runtime//:compiler_tfrt_op_interfaces", "@tf_runtime//:compiler_tfrt_traits", diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc index 50d4cb1214250b..b4e337f328b27b 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h" @@ -73,17 +74,17 @@ mlir::Type MlrtDialect::parseType(mlir::DialectAsmParser &parser) const { // Print a type registered to this dialect. void MlrtDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { - if (type.isa()) { + if (mlir::isa(type)) { os << "future"; return; } - if (type.isa()) { + if (mlir::isa(type)) { os << "promise"; return; } - if (type.isa()) { + if (mlir::isa(type)) { os << "async_handle"; return; } diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td index 94416661455a9c..b260dcb402f3f2 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td @@ -26,7 +26,6 @@ def Mlrt_Dialect : Dialect { }]; let cppNamespace = "::mlrt::compiler"; - let usePropertiesForAttributes = 0; } def MlrtFutureType : DialectType traits = []> : diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc index fc4cb6a93a28ea..d6ddc8f96fd901 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" @@ -74,7 +75,7 @@ mlir::Type TensorflowMlrtDialect::parseType( // Print a type registered to this dialect. void TensorflowMlrtDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { - if (type.isa()) { + if (mlir::isa(type)) { os << "tensor"; return; } diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td index fcbf2358b3b936..0659143f49b39b 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td @@ -449,27 +449,21 @@ def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", [Pure]> { let summary = "Loads a variable tensor as an IFRT array for mlrt"; let description = [{ - This is the MLRT version of tf.IfrtLoadVariableOp. + This op loads a restored variable tensor as a tensor future. It is a + replacement of `tf.ReadVariableOp`. - This op loads a variable tensor as an IFRT array and binds it with the specified name. + This op returns a scalar string tensor containing the restored variable name, which can be + used as a key within the runtime, as well as a future for the tensor. - This op is an replacement of `tf.ReadVariableOp` in the case that a constant - variable tensor is an input to the tpu program invoked by `tf.IfrtCall`. - - After a `tf.ReadVariableOp` is lowered into `tf.IfrtLoadVariableOp`, the `tf.IfrtCall` kernel - will bind the loaded IFRT array by name with the tpu program's input. 
- - `tf.IfrtLoadVariableOp` converts the tensor into an IFRT array based on device and sharding - configuration specified in `VariableDeviceShardingConfigProto`. - - This op returns a scalar string tensor as a key for user to look for the loaded array - and a future containing the restored tensor. + The `tf.IfrtCall` kernel uses the output $array_key. + Other ops executed by TFRT may make use of $tensor_future. }]; let arguments = (ins TFTensorType:$variable, StrAttr:$device_sharding_config_proto_text, - StrAttr:$name + StrAttr:$name, + DefaultValuedAttr:$used_by_host ); let results = (outs diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td index 0791423a91c17f..fa08ea5907ac81 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td @@ -172,7 +172,8 @@ def TFIfrtLoadVariableOp: TensorflowMlrt_Op<"tf_ifrt_load_variable", [Pure]> { let arguments = (ins TF_Tensor:$variable, StrAttr:$device_sharding_config_proto_text, - StrAttr:$name + StrAttr:$name, + DefaultValuedAttr:$used_by_host ); let results = (outs diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.cc index 4bc8a6842bffe1..dd47e81ee1a6df 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.cc @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace tfrt { namespace fallback { @@ -47,12 +48,12 @@ Type FallbackDialect::parseType(DialectAsmParser &parser) const { /// Print a type registered to this dialect. void FallbackDialect::printType(Type type, DialectAsmPrinter &os) const { - if (type.isa()) { + if (mlir::isa(type)) { os << "tf_tensor"; return; } - if (type.isa()) { + if (mlir::isa(type)) { os << "tf_allocator"; return; } diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td index 0f74b0feca1821..0c42590f9aa7ee 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td @@ -27,7 +27,6 @@ def Fallback_Dialect : Dialect { }]; let cppNamespace = "::tfrt::fallback"; - let usePropertiesForAttributes = 0; } // This corresponds to tensorflow::Tensor. diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.td b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.td index 8c8bcd0ab4ffac..5dd788ef328858 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.td +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.td @@ -33,7 +33,6 @@ def FallbackAsync_Dialect : Dialect { }]; let cppNamespace = "::tfrt::fallback_async"; - let usePropertiesForAttributes = 0; } class FallbackAsync_Op traits = []> : diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc index 3a835b3796962d..30f6aa234a2d59 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace tfrt { namespace fallback_common { @@ -31,8 +32,8 @@ void GetExecuteOpAttrsCommon( mlir::Builder builder(context); for (auto iter : op_attr_array) { - auto key_value = iter.cast().getValue(); - llvm::StringRef key = key_value[0].cast().getValue(); + auto key_value = mlir::cast(iter).getValue(); + llvm::StringRef key = mlir::cast(key_value[0]).getValue(); mlir::Attribute value = key_value[1]; op_attrs->push_back({key, value}); } diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h index e78d247c038c64..0cddb1017a33d8 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime namespace tfrt { @@ -30,9 +31,9 @@ template mlir::LogicalResult VerifyExecuteOpCommon(OpTy op) { auto op_attr_array = op.getOpAttrs().getValue(); for (auto op_attr : op_attr_array) { - auto key_value = op_attr.template dyn_cast(); + auto key_value = mlir::dyn_cast(op_attr); if (!key_value || key_value.getValue().size() != 2 || - !key_value.getValue()[0].template isa()) + !mlir::isa(key_value.getValue()[0])) return op.emitOpError() << "each op_attr should be a key-value pair, " "where the key is a string"; } @@ -47,10 +48,10 @@ mlir::LogicalResult VerifyFallbackExecuteOp(OpTy op) { // Verify function attributes. 
auto op_func_attr_array = op.getOpFuncAttrs().getValue(); for (auto op_attr : op_func_attr_array) { - auto key_value = op_attr.template dyn_cast(); + auto key_value = mlir::dyn_cast(op_attr); if (!key_value || key_value.getValue().size() != 2 || - !key_value.getValue()[0].template isa() || - !key_value.getValue()[1].template isa()) + !mlir::isa(key_value.getValue()[0]) || + !mlir::isa(key_value.getValue()[1])) return op.emitOpError() << "each op_func_attr should be a key-value " "pair, where both the key and the value are " "strings"; @@ -63,11 +64,11 @@ void PrintExecuteOpFuncAttribute(mlir::OpAsmPrinter &p, OpTy op) { auto op_func_attrs = op.getOpFuncAttrs(); if (!op_func_attrs.empty()) { auto print_key_value = [&](mlir::Attribute attr) { - auto key_value = attr.cast().getValue(); + auto key_value = mlir::cast(attr).getValue(); auto key = key_value[0]; auto value = key_value[1]; - p << key.cast().getValue(); + p << mlir::cast(key).getValue(); p << " = "; p << value; }; @@ -84,11 +85,11 @@ void PrintExecuteOpCommon(mlir::OpAsmPrinter &p, OpTy op) { auto op_attrs = op.getOpAttrs(); if (!op_attrs.empty()) { auto print_key_value = [&](mlir::Attribute attr) { - auto key_value = attr.cast().getValue(); + auto key_value = mlir::cast(attr).getValue(); auto key = key_value[0]; auto value = key_value[1]; - p << key.cast().getValue(); + p << mlir::cast(key).getValue(); p << " = "; p << value; }; diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.cc index 8083fcac076745..6a429ef275e869 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tfrt/core_runtime/opdefs/attributes.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime +#include "tfrt/tensor/opdefs/tensor.h" // from @tf_runtime namespace tfrt { namespace fallback_sync { @@ -50,7 +51,7 @@ FallbackSyncDialect::FallbackSyncDialect(MLIRContext *context) } static Type GetTensorType(Builder *builder) { - return tfrt::t::TensorType::get(builder->getContext()); + return tfrt::tfrt_tensor::TensorType::get(builder->getContext()); } } // namespace fallback_sync diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc index 50b1a199e47c17..c63952f55de1b5 100644 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc +++ b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc @@ -147,7 +147,6 @@ void RuntimeFallbackExecutor::Prepare(llvm::StringRef mlir_input) { pipeline_opts.sink_in_invariant_ops = false; pipeline_opts.cost_threshold = 1024; pipeline_opts.merge_inter_dependent_streams = true; - pipeline_opts.func_use_fallback_tensor = true; mlir::PassManager pm(module->getContext()); pm.addPass(CreateTfToTfrtConversionPass(pipeline_opts)); diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc index d02155c88b7e22..93d50a012a6fed 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc @@ -27,6 +27,7 @@ limitations under the License. 
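For reference, each entry the verifiers above accept is a two-element `ArrayAttr` of the form `[key, value]` with a `StringAttr` key; a minimal construction sketch (the helper name is hypothetical):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

// Build one key-value entry of the form the verifiers above accept.
mlir::ArrayAttr MakeOpAttrEntry(mlir::MLIRContext* context,
                                llvm::StringRef key, mlir::Attribute value) {
  mlir::Builder builder(context);
  return builder.getArrayAttr({builder.getStringAttr(key), value});
}
```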
#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -42,9 +43,9 @@ namespace { using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; llvm::StringRef ProcessIndexPath(mlir::ArrayAttr index_path) { - if (index_path.size() == 1 && index_path[0].isa()) { + if (index_path.size() == 1 && mlir::isa(index_path[0])) { // TODO(chky): Support cases where index_path is not a single string. - return index_path[0].cast().getValue(); + return mlir::cast(index_path[0]).getValue(); } return ""; } @@ -71,7 +72,7 @@ Status MapFunctionSignaturesFromTFSavedModelMLIR( llvm::function_ref map_fn) { // Create bound inputs for each functions. mlir::SymbolTable symbol_table(module); - tensorflow::Status status = OkStatus(); + tensorflow::Status status = absl::OkStatus(); module.walk([&symbol_table, map_fn, &status](mlir::func::FuncOp func) { // Use the exported name as the function name, and skip non-exported // functions. @@ -92,8 +93,8 @@ Status MapFunctionSignaturesFromTFSavedModelMLIR( if (auto input_index_path = func.getArgAttrOfType( i, kTfSavedModelIndexPathAttr)) { input_names.push_back(ProcessIndexPath(input_index_path)); - auto statusor_spec = - ProcessTensorSpec(func_type.getInput(i).cast()); + auto statusor_spec = ProcessTensorSpec( + mlir::cast(func_type.getInput(i))); if (!statusor_spec.ok()) { status = std::move(statusor_spec).status(); return mlir::WalkResult::interrupt(); @@ -120,8 +121,8 @@ Status MapFunctionSignaturesFromTFSavedModelMLIR( if (auto output_index_path = func.getResultAttrOfType( i, kTfSavedModelIndexPathAttr)) { output_names.push_back(ProcessIndexPath(output_index_path)); - auto statusor_spec = - ProcessTensorSpec(func_type.getResult(i).cast()); + auto statusor_spec = ProcessTensorSpec( + mlir::cast(func_type.getResult(i))); if (!statusor_spec.ok()) { status = std::move(statusor_spec).status(); return mlir::WalkResult::interrupt(); diff --git a/tensorflow/compiler/mlir/tfrt/tests/batch_function_lowering.mlir b/tensorflow/compiler/mlir/tfrt/tests/batch_function_lowering.mlir index 9f6e47567f8c65..6944b477f535e6 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/batch_function_lowering.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/batch_function_lowering.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="func-use-fallback-tensor=true" %s | FileCheck %s --dump-input=always +// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline %s | FileCheck %s --dump-input=always func.func private @batched_function(%arg0: tensor<1x3xf32> {tf._user_specified_name = "0"}, %arg1: tensor<*x!tf_type.resource>) -> tensor<1x3xf32> attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { %0 = "tf.ReadVariableOp"(%arg1) {device = "/device:CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir index dec4b733d25b19..39c10c07cdcf35 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir @@ -9,6 +9,7 
@@ // CHECK-NEXT: [[KEY:%.*]], [[FUTURE:%.*]] = "tf.IfrtLoadVariable"([[HANDLE2]]) // CHECK-SAME: device_sharding_config_proto_text = "sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } device_ids: 0 device_ids: 1 " // CHECK-SAME: name = "__y" +// CHECK-SAME: used_by_host = false // CHECK-NEXT: [[RES:%.*]] = "tf.IfrtCall"([[KEY]], %arg0) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [0 : i32]}> // CHECK-SAME: : (tensor, tensor<1x3xf32>) -> tensor<1x1xf32> // CHECK-NEXT: return [[RES]] : tensor<1x1xf32> @@ -29,6 +30,7 @@ module { // CHECK: "tf.VarHandleOp" // CHECK-NOT: [[VARIABLE:%.*]] = "tf.ReadVariableOp" // CHECK-NEXT: [[KEY:%.*]], [[FUTURE:%.*]] = "tf.IfrtLoadVariable" +// CHECK-SAME: used_by_host = true // CHECK-NEXT: "tf.MatMul"(%arg0, [[FUTURE]]) // CHECK-NEXT: [[RES:%.*]] = "tf.IfrtCall"(%arg0, [[KEY]]) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> // CHECK-NEXT: return [[RES]] : tensor<1x1xf32> @@ -50,6 +52,7 @@ module { // CHECK: "tf.VarHandleOp" // CHECK-NOT: [[VARIABLE:%.*]] = "tf.ReadVariableOp" // CHECK-NEXT: [[KEY:%.*]], [[FUTURE:%.*]] = "tf.IfrtLoadVariable" +// CHECK-SAME: used_by_host = true // CHECK-NEXT: [[RES:%.*]] = "tf.MatMul"(%arg0, [[FUTURE]]) // CHECK-NEXT: return [[RES]] : tensor<1x1xf32> // diff --git a/tensorflow/compiler/mlir/tfrt/tests/lower_bound_batch_threads.mlir b/tensorflow/compiler/mlir/tfrt/tests/lower_bound_batch_threads.mlir new file mode 100644 index 00000000000000..317d9b3ad9e00a --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/lower_bound_batch_threads.mlir @@ -0,0 +1,53 @@ +// RUN: tf-tfrt-opt -split-input-file -tfrt-lower-bound-batch-threads="tfrt-min-num-batch-threads=2" %s | FileCheck %s --dump-input=always + +// ----- + +// The num_batch_threads is lowered bound to 2 from the original attribute of 1 + +// CHECK-LABEL: func private @batched_function +func.func private @batched_function(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> { + %2 = "tf.Identity"(%arg0) : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %2 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<1x3xf32>) -> tensor<*xf32> { + // CHECK: "tf.BatchFunction" + // CHECK-SAME: allowed_batch_sizes = [6] + // CHECK-SAME: batch_timeout_micros = 100000 : i64 + // CHECK-SAME: batching_queue = "" + // CHECK-SAME: container = "" + // CHECK-SAME: enable_large_batch_splitting = false + // CHECK-SAME: max_batch_size = 6 : i64 + // CHECK-SAME: max_enqueued_batches = 10 : i64 + // CHECK-SAME: num_batch_threads = 2 : i64 + // CHECK-SAME: shared_name = "batch/" + %1 = "tf.BatchFunction"(%arg0) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operandSegmentSizes = array, shared_name = "batch/"} : (tensor<1x3xf32>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> +} + +// ----- + +// The num_batch_threads remains 3 (the same as the original attribute) + +// CHECK-LABEL: func private @batched_function +func.func private @batched_function(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> { + %2 = "tf.Identity"(%arg0) : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %2 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<1x3xf32>) -> tensor<*xf32> { + // CHECK: 
"tf.BatchFunction" + // CHECK-SAME: allowed_batch_sizes = [6] + // CHECK-SAME: batch_timeout_micros = 100000 : i64 + // CHECK-SAME: batching_queue = "" + // CHECK-SAME: container = "" + // CHECK-SAME: enable_large_batch_splitting = false + // CHECK-SAME: max_batch_size = 6 : i64 + // CHECK-SAME: max_enqueued_batches = 10 : i64 + // CHECK-SAME: num_batch_threads = 3 : i64 + // CHECK-SAME: shared_name = "batch/" + %1 = "tf.BatchFunction"(%arg0) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 3 : i64, operandSegmentSizes = array, shared_name = "batch/"} : (tensor<1x3xf32>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir index e1ad0aea205007..24e015734a732f 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir @@ -5,7 +5,7 @@ // CHECK-LABEL: func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> // CHECK-NEXT: [[HANDLE:%.*]] = "tf.VarHandleOp"() // CHECK-NEXT: [[ARRAYKEY:%.*]], [[FURTURE:%.*]] = "tf_mlrt.tf_ifrt_load_variable"([[HANDLE]]) -// CHECK-SAME: {device_sharding_config_proto_text = "sharding { }", name = "__y"} : (tensor>>) -> (tensor, !mlrt.future) +// CHECK-SAME: <{device_sharding_config_proto_text = "sharding { }", name = "__y", used_by_host = true}> : (tensor>>) -> (tensor, !mlrt.future) // CHECK-NEXT: [[TENSOR:%.*]] = "tf_mlrt.tf_await"([[FURTURE]]) : (!mlrt.future) -> tensor<3x1xf32> // CHECK-NEXT: "tf.MatMul"(%arg0, [[TENSOR]]) : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> // CHECK-NEXT: "tf.IfrtCall"(%arg0, [[ARRAYKEY]]) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> {__tpu_compile_metadata_text = "retvals { sharding { } }"} : (tensor<1x3xf32>, tensor) -> tensor<1x1xf32> @@ -13,7 +13,7 @@ // func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { %0 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> - %array_key, %tensor = "tf.IfrtLoadVariable"(%0) <{device_sharding_config_proto_text = "sharding { }", name = "__y"}> : (tensor>>) -> (tensor, tensor<3x1xf32>) + %array_key, %tensor = "tf.IfrtLoadVariable"(%0) <{device_sharding_config_proto_text = "sharding { }", name = "__y", used_by_host = true}> : (tensor>>) -> (tensor, tensor<3x1xf32>) %1 = "tf.MatMul"(%arg0, %tensor) : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> %2 = "tf.IfrtCall"(%arg0, %array_key) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> {__tpu_compile_metadata_text = "retvals { sharding { } }"} : (tensor<1x3xf32>, tensor) -> tensor<1x1xf32> return %2 : tensor<1x1xf32> diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir index 3151daf80ec759..e83e208967b334 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir @@ -470,7 +470,8 @@ func.func @ifrt_load_variable_test() -> () { // CHECK-NEXT: "tf_mlrt.ifrt_load_variable"([[HANDLE]]) // CHECK-SAME: device_sharding_config_proto_text // CHECK-SAME: name = "__variable" - %1, %2 = 
"tf_mlrt.tf_ifrt_load_variable"(%0) {device_sharding_config_proto_text = "sharding { } device_ids: 0 device_ids: 1 ", name = "__variable", __op_key = 2: i32, device = "/device:CPU:0"} : (tensor>>) -> (tensor, !mlrt.future) + // CHECK-SAME: used_by_host = true + %1, %2 = "tf_mlrt.tf_ifrt_load_variable"(%0) {used_by_host = true, device_sharding_config_proto_text = "sharding { } device_ids: 0 device_ids: 1 ", name = "__variable", __op_key = 2: i32, device = "/device:CPU:0"} : (tensor>>) -> (tensor, !mlrt.future) // CHECK-NEXT: mlrt.await_all_control // CHECK-NEXT: return func.return @@ -490,7 +491,7 @@ func.func @ifrt_restore_variable_test() -> () { %cst_1 = "tf.Const"() {__op_key = 2: i32, value = dense<["y"]> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> // CHECK-NEXT: [[HANDLE:%.*]] = tf_mlrt.executeop %handle = "tf.VarHandleOp"() {__op_key = 3: i32, container = "x", shared_name = "y"} : () -> tensor>> - // CHECK-NEXT: "tf_mlrt.ifrt_restore_variable"([[PREFIX]], [[NAME]], [[SLICE]], [[HANDLE]]) {restored_dtypes = [f32]} + // CHECK-NEXT: "tf_mlrt.ifrt_restore_variable"([[PREFIX]], [[NAME]], [[SLICE]], [[HANDLE]]) <{restored_dtypes = [f32]}> "tf.IfrtRestoreVariableOp"(%cst, %cst_1, %cst_0, %handle) {restored_dtypes = [f32]} : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor>>) -> () // CHECK-NEXT: return func.return diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir index 2b1f5fc9b17a4e..a74d6509a0ed4c 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir @@ -638,3 +638,49 @@ func.func private @tf.NestedWhileRegion_cond(%arg0: tensor, %arg1: tensor } +// ----- + +// Test a while to map_fn conversion is skipped if the tensor list cannot be found in the current function body. 
+ +// CHECK-LABEL: map/while_cond +func.func private @"map/while_cond"(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor) -> tensor { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<3> : tensor} : () -> tensor + %0 = "tf.Less"(%arg0, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.Less"(%arg1, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.LogicalAnd"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: map/while_body +func.func private @"map/while_body"(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor) -> (tensor, tensor, tensor>>, tensor) { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00]> : tensor<9xf32>} : () -> tensor<9xf32> + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<3> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_3 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]> : tensor<9xf32>} : () -> tensor<9xf32> + %cst_4 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg0, %cst_4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.Mul"(%arg3, %cst_3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor<9xf32>) -> tensor<9xf32> + %2 = "tf.Reshape"(%1, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<9xf32>, tensor<2xi32>) -> tensor<3x3xf32> + %3 = "tf.AddV2"(%arg1, %cst_4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %4 = "tf.GatherV2"(%cst_1, %arg1, %cst_0) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<3xi32>, tensor, tensor) -> tensor + %5 = "tf.Cast"(%4) {Truncate = false, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor + %6 = "tf.Mul"(%5, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor<9xf32>) -> tensor<9xf32> + %7 = "tf.Reshape"(%6, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<9xf32>, tensor<2xi32>) -> tensor<3x3xf32> + %8 = "tf.MatMul"(%2, %7) {device = "/job:localhost/replica:0/task:0/device:CPU:0", transpose_a = false, transpose_b = false} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + %9 = "tf.MatrixDeterminant"(%8) {T = f32, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<3x3xf32>) -> tensor + %10 = "tf.TensorListSetItem"(%arg2, %arg1, %9) {device = "/job:localhost/replica:0/task:0/device:CPU:0", resize_if_index_out_of_bounds = false} : (tensor>>, tensor, tensor) -> tensor>> + return %0, %3, %10, %arg3 : tensor, tensor, tensor>>, tensor +} + +//CHECK-LABEL: @func +func.func 
@func(%arg0: tensor, %arg1: tensor>>) -> tensor<3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input:0", outputs = "PartitionedCall:0"}} { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<-1> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<3> : tensor} : () -> tensor + // CHECK-NOT: tf_map_fn + %1:4 = "tf.While"(%cst, %cst, %arg1, %arg0) {_lower_using_switch_merge = true, _num_original_outputs = 6 : i64, _read_only_resource_inputs = [], _xla_propagate_compile_time_consts = true, body = @"map/while_body", cond = @"map/while_cond", device = "/job:localhost/replica:0/task:0/device:CPU:0", is_stateless = true, parallel_iterations = 4 : i64, shape_invariant} : (tensor, tensor, tensor>>, tensor) -> (tensor, tensor, tensor>>, tensor) + %2 = "tf.TensorListStack"(%1#2, %cst_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0", num_elements = 3 : i64} : (tensor>>, tensor<0xi32>) -> tensor<3xf32> + return %2 : tensor<3xf32> +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir index 77e795b8bf47a1..f8951181427fd6 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -tf-to-tfrt=func-use-fallback-tensor=true %s | FileCheck %s --dump-input=fail +// RUN: tf-tfrt-opt -tf-to-tfrt %s | FileCheck %s --dump-input=fail // _output_shapes and f.* attributes are removed during tf-to-tfrt lowering. 
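// For instance (hypothetical op, shown only to illustrate the line above; it
// is not part of this test), an input such as
//   %0 = "tf.SomeOp"(%arg) {_output_shapes = [#tf_type.shape<1x3>], f.some_attr = true} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// would be expected to lower to a fallback invocation that carries neither
// _output_shapes nor any f.* attribute.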
// CHECK-LABEL: func @remove_unused_attr diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir index 1f8eb1ee6ee01e..57febbbc0ab14c 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -pass-pipeline='builtin.module(func.func(tf-tensor-device-copy),tfrt-lower-tf-savedmodel{hoist-invariant-ops=true},tf-to-tfrt{func-use-fallback-tensor=true tfrt-cost-threshold=1024 tfrt-merge-inter-dependent-streams=true})' %s | FileCheck %s --dump-input-filter=all +// RUN: tf-tfrt-opt -pass-pipeline='builtin.module(func.func(tf-tensor-device-copy),tfrt-lower-tf-savedmodel{hoist-invariant-ops=true},tf-to-tfrt{tfrt-cost-threshold=1024 tfrt-merge-inter-dependent-streams=true})' %s | FileCheck %s --dump-input-filter=all // CHECK-NOT: tf_saved_model.semantics // CHECK: tfrt.cost_threshold = 1024 diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir index b208fe390acc3f..6596d650889384 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir @@ -9,12 +9,12 @@ func.func @string_tensor() -> (tensor<0x!tf_type.string>, tensor<7x!tf_type.stri func.return %0, %1 : tensor<0x!tf_type.string>, tensor<7x!tf_type.string> } -// Convert tf.Const to corert.const_dense_tensor only on cpu device +// Convert tf.Const to tfrt_fallback_async.const_dense_tensor only on cpu device // CHECK-LABEL: func @dense_tensor func.func @dense_tensor() -> tensor<4xui64> { - // CHECK: corert.const_dense_tensor dense<[1, 2, 3, 4]> : tensor<4xui64> + // CHECK: tfrt_fallback_async.const_dense_tensor dense<[1, 2, 3, 4]> : tensor<4xui64> %0 = "tf.Const"() {value = dense<[1, 2, 3, 4]> : tensor<4xui64>} : () -> tensor<4xui64> - // CHECK: corert.const_dense_tensor dense<1.000000e+00> : tensor<1xbf16> + // CHECK: tfrt_fallback_async.const_dense_tensor dense<1.000000e+00> : tensor<1xbf16> %1 = "tf.Const"() {device = "/device:CPU:0", value = dense<[1.0]> : tensor<1xbf16>} : () -> tensor<4xbf16> // CHECK: corert.executeop({{.*}}) "tf.Const"() {dtype = ui64, value = dense<[1, 2, 3, 4]> : tensor<4xui64>} : 1 %2 = "tf.Const"() {device = "/device:GPU:0", value = dense<[1, 2, 3, 4]> : tensor<4xui64>} : () -> tensor<4xui64> diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/control_flow.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/control_flow.mlir index dac8c0a71c15fb..ad3232042ca5e7 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/control_flow.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/control_flow.mlir @@ -1,44 +1,44 @@ -// RUN: tf-tfrt-opt -tf-to-tfrt %s | FileCheck %s --dump-input=fail +// RUN: tf-tfrt-opt -tf-to-tfrt="enable-while-parallel-iterations=true" %s | FileCheck %s --dump-input=fail -// CHECK-LABEL: func @cond_false(%arg0: !tfrt.chain, %arg1: !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) +// CHECK-LABEL: func @cond_false(%arg0: !tfrt.chain, %arg1: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) func.func @cond_false(%arg0: tensor) -> tensor { %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<-1> : tensor} : () -> tensor %1 = "tf.Add"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor func.return %1 : tensor } -// CHECK-LABEL: func @cond_true(%arg0: 
!tfrt.chain, %arg1: !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) +// CHECK-LABEL: func @cond_true(%arg0: !tfrt.chain, %arg1: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) func.func @cond_true(%arg0: tensor) -> tensor { %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor %1 = "tf.Add"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor func.return %1 : tensor } -// CHECK-LABEL: func @cond(%arg0: !tfrt.chain, %arg1: !corert.tensorhandle, %arg2: !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) +// CHECK-LABEL: func @cond(%arg0: !tfrt.chain, %arg1: !tfrt_fallback.tf_tensor, %arg2: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) func.func @cond(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: [[cond:%.*]] = tfrt_fallback_async.predicate // CHECK: [[cond_res:%.*]]:2 = tfrt.cond [[cond]] - // CHECK-SAME: @cond_true @cond_false(%arg0, %arg2) : (!tfrt.chain, !corert.tensorhandle) + // CHECK-SAME: @cond_true @cond_false(%arg0, %arg2) : (!tfrt.chain, !tfrt_fallback.tf_tensor) %2 = "tf.If"(%arg0, %arg1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = true} : (tensor, tensor) -> tensor // CHECK: [[out_ch:%.*]] = tfrt.merge.chains [[cond_res]]#0, %arg0 : !tfrt.chain, !tfrt.chain - // CHECK: tfrt.return [[out_ch]], [[cond_res]]#1 : !tfrt.chain, !corert.tensorhandle + // CHECK: tfrt.return [[out_ch]], [[cond_res]]#1 : !tfrt.chain, !tfrt_fallback.tf_tensor func.return %2 : tensor } -// CHECK-LABEL: func @cond_stateful(%arg0: !tfrt.chain, %arg1: !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) +// CHECK-LABEL: func @cond_stateful(%arg0: !tfrt.chain, %arg1: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) func.func @cond_stateful(%arg0: tensor) -> tensor { %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor %1 = "tf.Less"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor // CHECK: [[cond_res:%.*]]:2 = tfrt.cond - // CHECK-SAME: @cond_true @cond_false(%arg0, %arg1) : (!tfrt.chain, !corert.tensorhandle) + // CHECK-SAME: @cond_true @cond_false(%arg0, %arg1) : (!tfrt.chain, !tfrt_fallback.tf_tensor) %2 = "tf.If"(%1, %arg0) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor, tensor) -> tensor // Note: returns %out_op_chain. 
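// (Contrast with the stateless @cond above, where the branch out-chain is
// merged with the incoming chain; for this stateful tf.If the branch
// out-chain is returned directly so its side effects stay ordered before
// later users. Roughly, as an illustration rather than extra FileCheck
// directives:
//   stateless: [[ch]] = tfrt.merge.chains [[cond_res]]#0, %arg0
//   stateful:  tfrt.return [[cond_res]]#0, [[cond_res]]#1)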
- // CHECK: tfrt.return [[cond_res]]#0, [[cond_res]]#1 : !tfrt.chain, !corert.tensorhandle + // CHECK: tfrt.return [[cond_res]]#0, [[cond_res]]#1 : !tfrt.chain, !tfrt_fallback.tf_tensor func.return %2 : tensor } // CHECK-LABEL: func @while_cond_lt9 -// CHECK-SAME: ({{%.+}}: !tfrt.chain, {{%.+}}: !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) +// CHECK-SAME: ({{%.+}}: !tfrt.chain, {{%.+}}: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) func.func @while_cond_lt9(%arg0: tensor) -> tensor { %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<9> : tensor} : () -> tensor %1 = "tf.Less"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor @@ -46,7 +46,7 @@ func.func @while_cond_lt9(%arg0: tensor) -> tensor { } // CHECK-LABEL: func @while_body_add2 -// CHECK-SAME: ({{%.+}}: !tfrt.chain, {{%.+}}: !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) +// CHECK-SAME: ({{%.+}}: !tfrt.chain, {{%.+}}: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) func.func @while_body_add2(%arg0: tensor) -> tensor { %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<2> : tensor} : () -> tensor %1 = "tf.Add"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor @@ -54,28 +54,26 @@ func.func @while_body_add2(%arg0: tensor) -> tensor { } // CHECK-LABEL: func @while_test -// CHECK-SAME: ([[ARG0:%.+]]: !tfrt.chain) -> (!tfrt.chain, !corert.tensorhandle) +// CHECK-SAME: ([[ARG0:%.+]]: !tfrt.chain) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) func.func @while_test() -> (tensor) { - // CHECK: [[CONST:%.+]] = corert.const_dense_tensor dense<0> : tensor + // CHECK: [[CONST:%.*]] = tfrt_fallback_async.const_dense_tensor dense<0> : tensor %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor - // CHECK: [[pred_res:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate"([[ARG0]], [[CONST]]) : (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, i1) + // CHECK: [[pred_res:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate"([[ARG0]], [[CONST]]) : (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, i1) // CHECK: [[while_res:%.]]:2 = tfrt.while [[pred_res]]#1 @"while_body_add2/tfrt_body_1"([[pred_res]]#0, [[CONST]]) - // CHECK-SAME: (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) + // CHECK-SAME: (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) %1 = "tf.While"(%0) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) // CHECK: [[out_chain:%.*]] = tfrt.merge.chains [[while_res]]#0, [[ARG0]] - // CHECK: tfrt.return [[out_chain]], [[while_res]]#1 : !tfrt.chain, !corert.tensorhandle + // CHECK: tfrt.return [[out_chain]], [[while_res]]#1 : !tfrt.chain, !tfrt_fallback.tf_tensor func.return %1 : tensor } -// CHECK: func @"while_body_add2/tfrt_body_1"([[ch:%.*]]: !tfrt.chain, [[arg:%.*]]: !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle, i1) -// CHECK: [[body_res:%.*]]:2 = tfrt.call @while_body_add2([[ch]], [[arg]]) : (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) -// CHECK: [[pred_res:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate"([[body_res]]#0, [[body_res]]#1) : (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, i1) -// CHECK: tfrt.return [[pred_res]]#0, [[body_res]]#1, [[pred_res]]#1 : !tfrt.chain, !corert.tensorhandle, i1 +// CHECK: func @"while_body_add2/tfrt_body_1"([[ch:%.*]]: !tfrt.chain, [[arg:%.*]]: 
!tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor, i1) +// CHECK: [[body_res:%.*]]:2 = tfrt.call @while_body_add2([[ch]], [[arg]]) : (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) +// CHECK: [[pred_res:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate"([[body_res]]#0, [[body_res]]#1) : (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, i1) +// CHECK: tfrt.return [[pred_res]]#0, [[body_res]]#1, [[pred_res]]#1 : !tfrt.chain, !tfrt_fallback.tf_tensor, i1 -// CHECK: func @"while_cond_lt9/tfrt_predicate"([[ch:%.*]]: !tfrt.chain, [[arg:%.*]]: !corert.tensorhandle) -> (!tfrt.chain, i1) -// CHECK: [[cond_res:%.*]]:2 = tfrt.call @while_cond_lt9([[ch]], [[arg]]) : (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) -// CHECK: [[cond:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[cond_res]]#1 -// CHECK-SAME: (!corert.tensorhandle) -> (!tfrt_fallback.tf_tensor) -// CHECK: [[bool_cond:%.*]] = tfrt_fallback_async.predicate [[cond]] +// CHECK: func @"while_cond_lt9/tfrt_predicate"([[ch:%.*]]: !tfrt.chain, [[arg:%.*]]: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, i1) +// CHECK: [[cond_res:%.*]]:2 = tfrt.call @while_cond_lt9([[ch]], [[arg]]) : (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) +// CHECK: [[bool_cond:%.*]] = tfrt_fallback_async.predicate [[cond_res]]#1 // CHECK: tfrt.return [[cond_res]]#0, [[bool_cond]] : !tfrt.chain, i1 // CHECK-LABEL: func @multi_while_test @@ -83,30 +81,102 @@ func.func @multi_while_test() -> (tensor, tensor) { %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor %1 = "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor // CHECK: [[pred_0:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate" - // CHECK: tfrt.while [[pred_0]]#1 @"while_body_add2/tfrt_body_1" + // CHECK: tfrt.while [[pred_0]]#1 @"while_body_add2/tfrt_body_10" + // CHECK-SAME: parallel_iterations(10) // CHECK: [[pred_1:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate" // CHECK: tfrt.while [[pred_1]]#1 @"while_body_add2/tfrt_body_1" - %2 = "tf.While"(%0) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) + // CHECK-SAME: parallel_iterations(1) + %2 = "tf.While"(%0) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 10} : (tensor) -> (tensor) %3 = "tf.While"(%1) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) func.return %2, %3 : tensor, tensor } +func.func @side_effect_while_cond_lt9(%arg: tensor>>) -> tensor { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<9> : tensor} : () -> tensor + %1 = "tf.ReadVariableOp"(%arg) {device = "/device:CPU:0", dtype = i32} : (tensor>>) -> tensor + %2 = "tf.Less"(%1, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +func.func @side_effect_while_body_add2(%arg: tensor>>) -> (tensor>>) { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<2> : tensor} : () -> tensor + %1 = "tf.ReadVariableOp"(%arg) {device = "/device:CPU:0", dtype = i32} : (tensor>>) -> tensor + %2 = "tf.Add"(%1, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%arg, %2) {device = "/device:CPU:0"} : (tensor>>, tensor) -> () + func.return %arg : tensor>> +} + +// CHECK-LABEL: func @side_effect_while_test +func.func 
@side_effect_while_test() -> (tensor) { + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "c", shared_name = "v"} : () -> tensor>> + // CHECK: [[while_res:%.]]:2 = tfrt.while {{%.*}} @"side_effect_while_body_add2/tfrt_body_1" + // CHECK: [[out_ch:%.*]], [[res:%.*]] = tfrt_fallback_async.executeop.seq([[while_res]]#0) {{.*}} "tf.ReadVariableOp" + %1 = "tf.While"(%0) { cond = @side_effect_while_cond_lt9, body = @side_effect_while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor>>) -> (tensor>>) + %2 = "tf.ReadVariableOp"(%1) {device = "/device:CPU:0", dtype = i32} : (tensor>>) -> tensor + func.return %2 : tensor +} + +func.func @tensor_array_while_cond(%index: tensor, %size: tensor, %flow_0: tensor, %flow_1: tensor, %handle_0: tensor<2x!tf_type.resource>>, %handle_1: tensor<2x!tf_type.resource>>) -> (tensor) { + %0 = "tf.Less"(%index, %size) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +func.func @tensor_array_while_body(%index: tensor, %size: tensor, %flow_0: tensor, %flow_1: tensor, %handle_0: tensor<2x!tf_type.resource>>, %handle_1: tensor<2x!tf_type.resource>>) -> (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) { + %cst = "tf.Const"() {value = dense<1.1> : tensor<100x512xf32>} : () -> tensor<100x512xf32> + %one = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %x = "tf.TensorArrayReadV3"(%handle_0, %index, %flow_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> tensor + %y = "tf.MatMul"(%x, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor<100x512xf32>) -> (tensor) + %flow_1_out = "tf.TensorArrayWriteV3"(%handle_1, %index, %y, %flow_1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) -> tensor + %next_index = "tf.AddV2"(%index, %one) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %next_index, %size, %flow_0, %flow_1_out, %handle_0, %handle_1 : tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>> +} + +// CHECK-LABEL: func @tensor_array_while_test +// CHECK-SAME: ([[in_chain:%.*]]: !tfrt.chain +func.func @tensor_array_while_test(%indices: tensor, %input_0: tensor, %input_1: tensor) -> (tensor, tensor) { + %index = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> (tensor) + %size = "tf.Const"() {device = "/device:CPU:0", value = dense<9> : tensor} : () -> (tensor) + %handle_0, %flow_0 = "tf.TensorArrayV3"(%size) {clear_after_read = true, device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = f32, dynamic_size = false, element_shape = #tf_type.shape, identical_element_shapes = true, tensor_array_name = "processed_embeddings/bidirectional_rnn/bw/bw/dynamic_rnn/input_0"} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) + %handle_1, %flow_1 = "tf.TensorArrayV3"(%size) {clear_after_read = true, device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = f32, dynamic_size = false, element_shape = #tf_type.shape, identical_element_shapes = true, tensor_array_name = "processed_embeddings/bidirectional_rnn/bw/bw/dynamic_rnn/output_0"} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) + %flow_01 = "tf.TensorArrayScatterV3"(%handle_0, %indices, %input_0, %flow_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) 
-> tensor + // CHECK: [[pred_0:%.*]]:2 = tfrt.call @"tensor_array_while_cond/tfrt_predicate"([[in_chain]] + // CHECK: [[while_res_0:%.*]]:7 = tfrt.while {{%.*}} @"tensor_array_while_body/tfrt_body_10"([[pred_0]]#0 + // CHECK-SAME: parallel_iterations(10) + %res_0:6 = "tf.While"(%index, %size, %flow_01, %flow_1, %handle_0, %handle_1) {body = @tensor_array_while_body, cond = @tensor_array_while_cond, device = "", is_stateless = false, parallel_iterations = 10 : i64} : (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) -> (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) + %output_0 = "tf.TensorArrayGatherV3"(%handle_1, %indices, %res_0#3) {device = "/job:localhost/replica:0/task:0/device:CPU:0", element_shape = #tf_type.shape} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> tensor + + %handle_2, %flow_2 = "tf.TensorArrayV3"(%size) {clear_after_read = true, device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = f32, dynamic_size = false, element_shape = #tf_type.shape, identical_element_shapes = true, tensor_array_name = "processed_embeddings/bidirectional_rnn/bw/bw/dynamic_rnn/input_0"} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) + %handle_3, %flow_3 = "tf.TensorArrayV3"(%size) {clear_after_read = true, device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = f32, dynamic_size = false, element_shape = #tf_type.shape, identical_element_shapes = true, tensor_array_name = "processed_embeddings/bidirectional_rnn/bw/bw/dynamic_rnn/output_0"} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) + %flow_21 = "tf.TensorArrayScatterV3"(%handle_2, %indices, %input_1, %flow_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) -> tensor + // CHECK: [[pred_1:%.*]]:2 = tfrt.call @"tensor_array_while_cond/tfrt_predicate"([[in_chain]] + // CHECK: [[while_res_1:%.*]]:7 = tfrt.while {{%.*}} @"tensor_array_while_body/tfrt_body_10"([[pred_1]]#0 + // CHECK-SAME: parallel_iterations(10) + %res_1:6 = "tf.While"(%index, %size, %flow_21, %flow_3, %handle_2, %handle_3) {body = @tensor_array_while_body, cond = @tensor_array_while_cond, device = "", is_stateless = false, parallel_iterations = 10 : i64} : (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) -> (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) + %output_1 = "tf.TensorArrayGatherV3"(%handle_3, %indices, %res_1#3) {device = "/job:localhost/replica:0/task:0/device:CPU:0", element_shape = #tf_type.shape} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> tensor + func.return %output_0, %output_1 : tensor, tensor +} + +// CHECK: func @"tensor_array_while_body/tfrt_body_10" + func.func @callee(%arg0: tensor) -> (tensor) { func.return %arg0: tensor } // CHECK-LABEL: func @call_test // CHECK-SAME: ([[chain:%.*]]: !tfrt.chain, -func.func @call_test(%arg0: tensor) -> (tensor, tensor) { +func.func @call_test(%arg0: tensor) -> (tensor, tensor, tensor) { %0 = "tf.Add"(%arg0, %arg0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor // CHECK: [[results_0:%.*]]:2 = tfrt.call @callee([[chain]] - // CHECK-SAME: (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) + // CHECK-SAME: (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) %1 = "tf.StatefulPartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor) -> 
(tensor) - // CHECK: [[results_1:%.*]]:2 = tfrt.call @callee([[chain]] - // CHECK-SAME: (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) + // CHECK-NEXT: [[results_1:%.*]]:2 = tfrt.call @callee([[chain]] + // CHECK-SAME: (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) %2 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor) -> (tensor) - // CHECK: [[results_0]]#1, [[results_1]]#1 - func.return %1, %2 : tensor, tensor + // CHECK-NEXT: [[results_2:%.*]]:2 = tfrt.call @callee([[chain]] + // CHECK-SAME: (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) + %3 = "tf.LegacyCall"(%0) {f = @callee} : (tensor) -> (tensor) + // CHECK: [[results_0]]#1, [[results_1]]#1, [[results_2]]#1 + func.return %1, %2, %3 : tensor, tensor, tensor } func.func @branch0(%arg0: tensor, %arg1: tensor) -> tensor { @@ -120,16 +190,12 @@ func.func @branch1(%arg0: tensor, %arg1: tensor) -> tensor { func.return %1 : tensor } -// CHECK-LABEL: func @case_test( -// CHECK-SAME: arg0: !tfrt.chain, -// CHECK-SAME: arg1: !corert.tensorhandle, -// CHECK-SAME: arg2: !corert.tensorhandle, -// CHECK-SAME: arg3: !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) { +// CHECK-LABEL: func @case_test +// CHECK-SAME: ([[chain:%.*]]: !tfrt.chain, [[tf_idx:%.*]]: !tfrt_fallback.tf_tensor, [[branch_arg0:%.*]]: !tfrt_fallback.tf_tensor, [[branch_arg1:%.*]]: !tfrt_fallback.tf_tensor) func.func @case_test(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: %[[res_idx:[^ ]+]] = corert.tensorhandle_to_int32 %arg1 - // CHECK: %[[case_out:[^ ]+]]:2 = tfrt.case %[[res_idx]] [@branch0, @branch1](%arg0, %arg2, %arg3) : (!tfrt.chain, !corert.tensorhandle, !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) - // CHECK: %[[out_chain:[^ ]+]] = tfrt.merge.chains %[[case_out]]#0, %arg0 : !tfrt.chain, !tfrt.chain + // CHECK: [[th_idx:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[tf_idx]] + // CHECK-NEXT: [[idx:%.*]] = corert.tensorhandle_to_int32 [[th_idx]] + // CHECK-NEXT: [[out:%.*]] = tfrt.case [[idx]] [@branch0, @branch1]([[chain]], [[branch_arg0]], [[branch_arg1]]) %0 = "tf.Case"(%arg0, %arg1, %arg2) {_lower_using_switch_merge = true, branches = [@branch0, @branch1], is_stateless = true} : (tensor, tensor, tensor) -> tensor - // CHECK: tfrt.return %[[out_chain]], %[[case_out]]#1 : !tfrt.chain, !corert.tensorhandle func.return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/decompose_resource_op.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/decompose_resource_op.mlir index a872b96a2fd6b4..ff0f0e7dbfd2cd 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/decompose_resource_op.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/decompose_resource_op.mlir @@ -4,13 +4,11 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-LABEL: func @gather // CHECK-SAME: ([[in_chain:%.*]]: !tfrt.chain -// CHECK-SAME: [[arg0:%.*]]: !corert.tensorhandle, [[arg1:%.*]]: !corert.tensorhandle) -// CHECK: [[const_th:%.*]] = corert.const_dense_tensor -// CHECK-NEXT: [[const:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[const_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0"} +// CHECK-SAME: [[arg0:%.*]]: !tfrt_fallback.tf_tensor, [[arg1:%.*]]: !tfrt_fallback.tf_tensor) +// CHECK: [[const:%.*]] = tfrt_fallback_async.const_dense_tensor // CHECK-NEXT: 
[[out_chain:%.*]], [[value:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key(0) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"({{.*}}) // CHECK-NEXT: [[res:%.*]] = tfrt_fallback_async.executeop key(1) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.GatherV2"([[value]], {{.*}}, [[const]]) -// CHECK-NEXT: [[res_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[res]] {device = "/job:localhost/replica:0/task:0/device:CPU:0"} -// CHECK-NEXT: tfrt.return [[out_chain]], [[res_th]] : !tfrt.chain, !corert.tensorhandle +// CHECK-NEXT: tfrt.return [[out_chain]], [[res]] : !tfrt.chain, !tfrt_fallback.tf_tensor func.func @gather(%indices: tensor, %resource: tensor<*x!tf_type.resource>) -> tensor<*xi32> { %0 = "tf.ResourceGather"(%resource, %indices) {batch_dims = 0 : i64, device = "/device:CPU:0", validate_indices = true}: (tensor<*x!tf_type.resource>, tensor) -> (tensor<*xi32>) diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir index 4c5777c28e2c98..e4dba7e395dbdf 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -tf-to-tfrt=func-use-fallback-tensor=true %s | FileCheck %s --dump-input=fail +// RUN: tf-tfrt-opt -tf-to-tfrt %s | FileCheck %s --dump-input=fail // CHECK-LABEL: func @device_test func.func @device_test( diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir index b1c8c45d8b7a3f..1b03794356f1f6 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir @@ -15,7 +15,7 @@ // CHECK-LABEL: func @main // CHECK-SAME: {{.*}} !tfrt.chain -// CHECK-SAME: [[serialized:%.*]]: !corert.tensorhandle +// CHECK-SAME: [[serialized:%.*]]: !tfrt_fallback.tf_tensor func.func @main(%serialized: tensor<32x!tf_type.string>) -> (tensor) attributes {tf.entry_function = {inputs = "input0", outputs = "ParseExample/ParseExampleV2"}} { %dense_default_0 = "tf.Const"() {device = "/device:CPU:0", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> %dense_default_1 = "tf.Const"() {device = "/device:CPU:0", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> @@ -24,10 +24,8 @@ func.func @main(%serialized: tensor<32x!tf_type.string>) -> (tensor) at %ragged_keys = "tf.Const"() {device = "/device:CPU:0", dtype = !tf_type.string, value = dense<""> : tensor<0x!tf_type.string>} : () -> tensor<0x!tf_type.string> %sparse_keys = "tf.Const"() {device = "/device:CPU:0", dtype = !tf_type.string, value = dense<""> : tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string> - // CHECK: [[fallback_serialized:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[serialized]] - // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:CPU:0" // CHECK: [[outputs:%.*]]:8 = tfrt_fallback_async.executeop key(0) cost({{.*}}) device("/device:CPU:0") "tf.ParseExampleV2" - // CHECK-SAME: ([[fallback_serialized]] + // CHECK-SAME: ([[serialized]] // CHECK-NOT: device // CHECK-SAME: Tdense = [f32, f32] // CHECK-SAME: dense_shapes = [#corert.shape<>, #corert.shape<>] @@ -44,9 +42,7 @@ func.func @main(%serialized: tensor<32x!tf_type.string>) -> (tensor) at } : 
(tensor<32x!tf_type.string>, tensor<0x!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>, tensor<0x!tf_type.string>, tensor<0xf32>, tensor<0xf32>) -> (tensor, tensor, tensor, tensor, tensor<2xi64>, tensor<2xi64>, tensor<32xf32>, tensor<32xf32>) - // CHECK: [[result:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[outputs]]#0 - // CHECK-SAME: device = "/device:CPU:0" - // CHECK: tfrt.return {{.*}}, [[result]] + // CHECK: tfrt.return {{.*}}, [[outputs]]#0 func.return %outputs#0 : tensor } diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_use_fallback_tensor.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_use_fallback_tensor.mlir deleted file mode 100644 index 017efc64c31a50..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_use_fallback_tensor.mlir +++ /dev/null @@ -1,207 +0,0 @@ -// RUN: tf-tfrt-opt -tf-to-tfrt="func-use-fallback-tensor=true enable-while-parallel-iterations=true" %s | FileCheck %s --dump-input=fail - -// This file tests the correctness of `func-use-fallback-tensor` option when -// converting from TF to TFRT. Since func op is used by the control flow ops, -// the test cases here should cover the control flow ops. - -// CHECK-LABEL: func @cond_false(%arg0: !tfrt.chain, %arg1: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) -func.func @cond_false(%arg0: tensor) -> tensor { - %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<-1> : tensor} : () -> tensor - %1 = "tf.Add"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @cond_true(%arg0: !tfrt.chain, %arg1: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) -func.func @cond_true(%arg0: tensor) -> tensor { - %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor - %1 = "tf.Add"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @cond(%arg0: !tfrt.chain, %arg1: !tfrt_fallback.tf_tensor, %arg2: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) -func.func @cond(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: [[cond:%.*]] = tfrt_fallback_async.predicate - // CHECK: [[cond_res:%.*]]:2 = tfrt.cond [[cond]] - // CHECK-SAME: @cond_true @cond_false(%arg0, %arg2) : (!tfrt.chain, !tfrt_fallback.tf_tensor) - %2 = "tf.If"(%arg0, %arg1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = true} : (tensor, tensor) -> tensor - // CHECK: [[out_ch:%.*]] = tfrt.merge.chains [[cond_res]]#0, %arg0 : !tfrt.chain, !tfrt.chain - // CHECK: tfrt.return [[out_ch]], [[cond_res]]#1 : !tfrt.chain, !tfrt_fallback.tf_tensor - func.return %2 : tensor -} - -// CHECK-LABEL: func @cond_stateful(%arg0: !tfrt.chain, %arg1: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) -func.func @cond_stateful(%arg0: tensor) -> tensor { - %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor - %1 = "tf.Less"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - // CHECK: [[cond_res:%.*]]:2 = tfrt.cond - // CHECK-SAME: @cond_true @cond_false(%arg0, %arg1) : (!tfrt.chain, !tfrt_fallback.tf_tensor) - %2 = "tf.If"(%1, %arg0) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor, tensor) -> tensor - // Note: returns %out_op_chain. 
- // CHECK: tfrt.return [[cond_res]]#0, [[cond_res]]#1 : !tfrt.chain, !tfrt_fallback.tf_tensor - func.return %2 : tensor -} - -// CHECK-LABEL: func @while_cond_lt9 -// CHECK-SAME: ({{%.+}}: !tfrt.chain, {{%.+}}: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) -func.func @while_cond_lt9(%arg0: tensor) -> tensor { - %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<9> : tensor} : () -> tensor - %1 = "tf.Less"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @while_body_add2 -// CHECK-SAME: ({{%.+}}: !tfrt.chain, {{%.+}}: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) -func.func @while_body_add2(%arg0: tensor) -> tensor { - %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<2> : tensor} : () -> tensor - %1 = "tf.Add"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @while_test -// CHECK-SAME: ([[ARG0:%.+]]: !tfrt.chain) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) -func.func @while_test() -> (tensor) { - // CHECK: [[CONST_TH:%.*]] = corert.const_dense_tensor dense<0> : tensor - %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor - // CHECK: [[CONST:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[CONST_TH]] - // CHECK: (!corert.tensorhandle) -> (!tfrt_fallback.tf_tensor) - // CHECK: [[pred_res:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate"([[ARG0]], [[CONST]]) : (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, i1) - // CHECK: [[while_res:%.]]:2 = tfrt.while [[pred_res]]#1 @"while_body_add2/tfrt_body_1"([[pred_res]]#0, [[CONST]]) - // CHECK-SAME: (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) - %1 = "tf.While"(%0) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) - // CHECK: [[out_chain:%.*]] = tfrt.merge.chains [[while_res]]#0, [[ARG0]] - // CHECK: tfrt.return [[out_chain]], [[while_res]]#1 : !tfrt.chain, !tfrt_fallback.tf_tensor - func.return %1 : tensor -} -// CHECK: func @"while_body_add2/tfrt_body_1"([[ch:%.*]]: !tfrt.chain, [[arg:%.*]]: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor, i1) -// CHECK: [[body_res:%.*]]:2 = tfrt.call @while_body_add2([[ch]], [[arg]]) : (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) -// CHECK: [[pred_res:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate"([[body_res]]#0, [[body_res]]#1) : (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, i1) -// CHECK: tfrt.return [[pred_res]]#0, [[body_res]]#1, [[pred_res]]#1 : !tfrt.chain, !tfrt_fallback.tf_tensor, i1 - -// CHECK: func @"while_cond_lt9/tfrt_predicate"([[ch:%.*]]: !tfrt.chain, [[arg:%.*]]: !tfrt_fallback.tf_tensor) -> (!tfrt.chain, i1) -// CHECK: [[cond_res:%.*]]:2 = tfrt.call @while_cond_lt9([[ch]], [[arg]]) : (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) -// CHECK: [[bool_cond:%.*]] = tfrt_fallback_async.predicate [[cond_res]]#1 -// CHECK: tfrt.return [[cond_res]]#0, [[bool_cond]] : !tfrt.chain, i1 - -// CHECK-LABEL: func @multi_while_test -func.func @multi_while_test() -> (tensor, tensor) { - %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor - %1 = "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor - // CHECK: [[pred_0:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate" 
- // CHECK: tfrt.while [[pred_0]]#1 @"while_body_add2/tfrt_body_10" - // CHECK-SAME: parallel_iterations(10) - // CHECK: [[pred_1:%.*]]:2 = tfrt.call @"while_cond_lt9/tfrt_predicate" - // CHECK: tfrt.while [[pred_1]]#1 @"while_body_add2/tfrt_body_1" - // CHECK-SAME: parallel_iterations(1) - %2 = "tf.While"(%0) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 10} : (tensor) -> (tensor) - %3 = "tf.While"(%1) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) - func.return %2, %3 : tensor, tensor -} - -func.func @side_effect_while_cond_lt9(%arg: tensor>>) -> tensor { - %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<9> : tensor} : () -> tensor - %1 = "tf.ReadVariableOp"(%arg) {device = "/device:CPU:0", dtype = i32} : (tensor>>) -> tensor - %2 = "tf.Less"(%1, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -func.func @side_effect_while_body_add2(%arg: tensor>>) -> (tensor>>) { - %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<2> : tensor} : () -> tensor - %1 = "tf.ReadVariableOp"(%arg) {device = "/device:CPU:0", dtype = i32} : (tensor>>) -> tensor - %2 = "tf.Add"(%1, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - "tf.AssignVariableOp"(%arg, %2) {device = "/device:CPU:0"} : (tensor>>, tensor) -> () - func.return %arg : tensor>> -} - -// CHECK-LABEL: func @side_effect_while_test -func.func @side_effect_while_test() -> (tensor) { - %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "c", shared_name = "v"} : () -> tensor>> - // CHECK: [[while_res:%.]]:2 = tfrt.while {{%.*}} @"side_effect_while_body_add2/tfrt_body_1" - // CHECK: [[out_ch:%.*]], [[res:%.*]] = tfrt_fallback_async.executeop.seq([[while_res]]#0) {{.*}} "tf.ReadVariableOp" - %1 = "tf.While"(%0) { cond = @side_effect_while_cond_lt9, body = @side_effect_while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor>>) -> (tensor>>) - %2 = "tf.ReadVariableOp"(%1) {device = "/device:CPU:0", dtype = i32} : (tensor>>) -> tensor - func.return %2 : tensor -} - -func.func @tensor_array_while_cond(%index: tensor, %size: tensor, %flow_0: tensor, %flow_1: tensor, %handle_0: tensor<2x!tf_type.resource>>, %handle_1: tensor<2x!tf_type.resource>>) -> (tensor) { - %0 = "tf.Less"(%index, %size) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -func.func @tensor_array_while_body(%index: tensor, %size: tensor, %flow_0: tensor, %flow_1: tensor, %handle_0: tensor<2x!tf_type.resource>>, %handle_1: tensor<2x!tf_type.resource>>) -> (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) { - %cst = "tf.Const"() {value = dense<1.1> : tensor<100x512xf32>} : () -> tensor<100x512xf32> - %one = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %x = "tf.TensorArrayReadV3"(%handle_0, %index, %flow_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> tensor - %y = "tf.MatMul"(%x, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor<100x512xf32>) -> (tensor) - %flow_1_out = "tf.TensorArrayWriteV3"(%handle_1, %index, %y, %flow_1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) -> tensor - %next_index = "tf.AddV2"(%index, %one) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor - 
func.return %next_index, %size, %flow_0, %flow_1_out, %handle_0, %handle_1 : tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>> -} - -// CHECK-LABEL: func @tensor_array_while_test -// CHECK-SAME: ([[in_chain:%.*]]: !tfrt.chain -func.func @tensor_array_while_test(%indices: tensor, %input_0: tensor, %input_1: tensor) -> (tensor, tensor) { - %index = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> (tensor) - %size = "tf.Const"() {device = "/device:CPU:0", value = dense<9> : tensor} : () -> (tensor) - %handle_0, %flow_0 = "tf.TensorArrayV3"(%size) {clear_after_read = true, device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = f32, dynamic_size = false, element_shape = #tf_type.shape, identical_element_shapes = true, tensor_array_name = "processed_embeddings/bidirectional_rnn/bw/bw/dynamic_rnn/input_0"} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) - %handle_1, %flow_1 = "tf.TensorArrayV3"(%size) {clear_after_read = true, device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = f32, dynamic_size = false, element_shape = #tf_type.shape, identical_element_shapes = true, tensor_array_name = "processed_embeddings/bidirectional_rnn/bw/bw/dynamic_rnn/output_0"} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) - %flow_01 = "tf.TensorArrayScatterV3"(%handle_0, %indices, %input_0, %flow_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) -> tensor - // CHECK: [[pred_0:%.*]]:2 = tfrt.call @"tensor_array_while_cond/tfrt_predicate"([[in_chain]] - // CHECK: [[while_res_0:%.*]]:7 = tfrt.while {{%.*}} @"tensor_array_while_body/tfrt_body_10"([[pred_0]]#0 - // CHECK-SAME: parallel_iterations(10) - %res_0:6 = "tf.While"(%index, %size, %flow_01, %flow_1, %handle_0, %handle_1) {body = @tensor_array_while_body, cond = @tensor_array_while_cond, device = "", is_stateless = false, parallel_iterations = 10 : i64} : (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) -> (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) - %output_0 = "tf.TensorArrayGatherV3"(%handle_1, %indices, %res_0#3) {device = "/job:localhost/replica:0/task:0/device:CPU:0", element_shape = #tf_type.shape} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> tensor - - %handle_2, %flow_2 = "tf.TensorArrayV3"(%size) {clear_after_read = true, device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = f32, dynamic_size = false, element_shape = #tf_type.shape, identical_element_shapes = true, tensor_array_name = "processed_embeddings/bidirectional_rnn/bw/bw/dynamic_rnn/input_0"} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) - %handle_3, %flow_3 = "tf.TensorArrayV3"(%size) {clear_after_read = true, device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = f32, dynamic_size = false, element_shape = #tf_type.shape, identical_element_shapes = true, tensor_array_name = "processed_embeddings/bidirectional_rnn/bw/bw/dynamic_rnn/output_0"} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) - %flow_21 = "tf.TensorArrayScatterV3"(%handle_2, %indices, %input_1, %flow_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) -> tensor - // CHECK: [[pred_1:%.*]]:2 = tfrt.call @"tensor_array_while_cond/tfrt_predicate"([[in_chain]] - // CHECK: [[while_res_1:%.*]]:7 = tfrt.while {{%.*}} @"tensor_array_while_body/tfrt_body_10"([[pred_1]]#0 - // 
CHECK-SAME: parallel_iterations(10) - %res_1:6 = "tf.While"(%index, %size, %flow_21, %flow_3, %handle_2, %handle_3) {body = @tensor_array_while_body, cond = @tensor_array_while_cond, device = "", is_stateless = false, parallel_iterations = 10 : i64} : (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) -> (tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>) - %output_1 = "tf.TensorArrayGatherV3"(%handle_3, %indices, %res_1#3) {device = "/job:localhost/replica:0/task:0/device:CPU:0", element_shape = #tf_type.shape} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> tensor - func.return %output_0, %output_1 : tensor, tensor -} - -// CHECK: func @"tensor_array_while_body/tfrt_body_10" - -func.func @callee(%arg0: tensor) -> (tensor) { - func.return %arg0: tensor -} - -// CHECK-LABEL: func @call_test -// CHECK-SAME: ([[chain:%.*]]: !tfrt.chain, -func.func @call_test(%arg0: tensor) -> (tensor, tensor, tensor) { - %0 = "tf.Add"(%arg0, %arg0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - // CHECK: [[results_0:%.*]]:2 = tfrt.call @callee([[chain]] - // CHECK-SAME: (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) - %1 = "tf.StatefulPartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor) -> (tensor) - // CHECK-NEXT: [[results_1:%.*]]:2 = tfrt.call @callee([[chain]] - // CHECK-SAME: (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) - %2 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor) -> (tensor) - // CHECK-NEXT: [[results_2:%.*]]:2 = tfrt.call @callee([[chain]] - // CHECK-SAME: (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) - %3 = "tf.LegacyCall"(%0) {f = @callee} : (tensor) -> (tensor) - // CHECK: [[results_0]]#1, [[results_1]]#1, [[results_2]]#1 - func.return %1, %2, %3 : tensor, tensor, tensor -} - -func.func @branch0(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.Add" (%arg0, %arg1) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -func.func @branch1(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.Add" (%arg0, %arg1) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - %1 = "tf.Add" (%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @case_test -// CHECK-SAME: ([[chain:%.*]]: !tfrt.chain, [[tf_idx:%.*]]: !tfrt_fallback.tf_tensor, [[branch_arg0:%.*]]: !tfrt_fallback.tf_tensor, [[branch_arg1:%.*]]: !tfrt_fallback.tf_tensor) -func.func @case_test(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: [[th_idx:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[tf_idx]] - // CHECK-NEXT: [[idx:%.*]] = corert.tensorhandle_to_int32 [[th_idx]] - // CHECK-NEXT: [[out:%.*]] = tfrt.case [[idx]] [@branch0, @branch1]([[chain]], [[branch_arg0]], [[branch_arg1]]) - %0 = "tf.Case"(%arg0, %arg1, %arg2) {_lower_using_switch_merge = true, branches = [@branch0, @branch1], is_stateless = true} : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline.mlir index dd57c72674a3e8..763e188fd4e5c7 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline.mlir +++ 
b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline.mlir @@ -5,15 +5,10 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: func @__forward_call_369 // CHECK-SAME: ([[in_chain:%.*]]: !tfrt.chain -// CHECK-SAME: [[arg1_th:%.*]]: !corert.tensorhandle {tf._user_specified_name = "inputs"}, -// CHECK-SAME: [[arg2_th:%.*]]: !corert.tensorhandle, [[arg3_th:%.*]]: !corert.tensorhandle, [[arg4_th:%.*]]: !corert.tensorhandle, [[arg5_th:%.*]]: !corert.tensorhandle) +// CHECK-SAME: [[arg1:%.*]]: !tfrt_fallback.tf_tensor {tf._user_specified_name = "inputs"}, +// CHECK-SAME: [[arg2:%.*]]: !tfrt_fallback.tf_tensor, [[arg3:%.*]]: !tfrt_fallback.tf_tensor, [[arg4:%.*]]: !tfrt_fallback.tf_tensor, [[arg5:%.*]]: !tfrt_fallback.tf_tensor) // CHECK-SAME: -> (!tfrt.chain // CHECK: [[o1:%.*]] = tfrt_fallback_async.const_dense_tensor -// CHECK-NEXT: [[arg1:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[arg1_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0" -// CHECK-NEXT: [[arg4:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[arg4_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0" -// CHECK-NEXT: [[arg5:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[arg5_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0" -// CHECK-NEXT: [[arg2:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[arg2_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0" -// CHECK-NEXT: [[arg3:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[arg3_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0" // CHECK: [[o2_chain:%.*]], [[o2:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg3]]) // CHECK-NEXT: [[o3_chain:%.*]], [[o3:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg2]]) // CHECK-NEXT: [[o4_chain:%.*]], [[o4:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg5]]) @@ -23,12 +18,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-NEXT: [[o8:%.*]] = tfrt_fallback_async.executeop key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.Reshape"([[o7]], [[o1]]) // CHECK-NEXT: [[o9:%.*]] = tfrt_fallback_async.executeop key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf._FusedMatMul"([[o8]], [[o5]], [[o4]]) // CHECK-NEXT: [[out_chain:%.*]] = tfrt.merge.chains [[o2_chain]], [[o3_chain]], [[o4_chain]], [[o5_chain]] -// CHECK-NEXT: [[o9_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[o9]] -// CHECK-NEXT: [[o5_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[o5]] -// CHECK-NEXT: [[o8_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[o8]] -// CHECK-NEXT: [[o6_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[o6]] -// CHECK-NEXT: [[o3_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[o3]] -// CHECK-NEXT: tfrt.return [[out_chain]], [[o9_th]], [[o5_th]], [[o8_th]], [[o6_th]], [[arg1_th]], [[o3_th]] : !tfrt.chain, !corert.tensorhandle, !corert.tensorhandle, !corert.tensorhandle, 
!corert.tensorhandle, !corert.tensorhandle, !corert.tensorhandle +// CHECK-NEXT: tfrt.return [[out_chain]], [[o9]], [[o5]], [[o8]], [[o6]], [[arg1]], [[o3]] : !tfrt.chain, !tfrt_fallback.tf_tensor, !tfrt_fallback.tf_tensor, !tfrt_fallback.tf_tensor, !tfrt_fallback.tf_tensor, !tfrt_fallback.tf_tensor, !tfrt_fallback.tf_tensor func.func @__forward_call_369(%arg0: tensor<16x224x224x3xf32> {tf._user_specified_name = "inputs"}, %arg1: tensor<*x!tf_type.resource>, %arg2: tensor<*x!tf_type.resource>, %arg3: tensor<*x!tf_type.resource>, %arg4: tensor<*x!tf_type.resource>) -> (tensor, tensor<*xf32>, tensor, tensor<16x112x112x?xf32>, tensor<16x224x224x3xf32>, tensor<*xf32>) attributes {tf.entry_function = {control_outputs = "", inputs = "inputs_0,conv1_conv2d_readvariableop_resource,conv1_biasadd_readvariableop_resource,fc1000_matmul_readvariableop_resource,fc1000_biasadd_readvariableop_resource", outputs = "identity_RetVal,fc1000_matmul_readvariableop_RetVal,flatten_reshape_RetVal,relu_RetVal,inputs_RetVal,conv1_conv2d_readvariableop_RetVal"}} { %0:6 = tf_executor.graph { %outputs, %control = tf_executor.island wraps "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>) -> tensor<*xf32> @@ -62,10 +52,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr } // CHECK-LABEL: func @while_test - // CHECK-SAME: ([[ARG0:%.+]]: !tfrt.chain) -> (!tfrt.chain, !corert.tensorhandle) + // CHECK-SAME: ([[ARG0:%.+]]: !tfrt.chain) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) func.func @while_test() -> (tensor) { // The predicate function should be inlined. - // CHECK: corert.const_dense_tensor dense<0> : tensor // CHECK-DAG: tfrt_fallback_async.const_dense_tensor dense<9> : tensor // CHECK-DAG: tfrt_fallback_async.const_dense_tensor dense<0> : tensor // CHECK-NEXT: tfrt_fallback_async.executeop key({{.*}}) cost({{.*}}) device("/device:CPU:0") "tf.Less" diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_refvar.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_refvar.mlir index e9518003023e96..481b5a421f0aa6 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_refvar.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_refvar.mlir @@ -1,12 +1,11 @@ // RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline %s | FileCheck %s --dump-input=fail // CHECK-LABEL: func @__inference_pruned_131 -// CHECK-SAME: ([[in_chain:%.*]]: !tfrt.chain) -> (!tfrt.chain, !corert.tensorhandle) +// CHECK-SAME: ([[in_chain:%.*]]: !tfrt.chain) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) // CHECK-NEXT: [[o_chain:%.*]], [[o:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key(0) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.VarHandleOp"() // CHECK-NEXT: [[o_chain_0:%.*]], [[o1:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key(1) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[o]]) {dtype = f32} : 1 // CHECK-NEXT: [[out_ch:%.*]] = tfrt.merge.chains [[o_chain]], [[o_chain_0]] -// CHECK-NEXT: [[o2:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[o1]] -// CHECK-NEXT: tfrt.return [[out_ch]], [[o2]] : !tfrt.chain, !corert.tensorhandle +// CHECK-NEXT: tfrt.return [[out_ch]], [[o1]] : !tfrt.chain, !tfrt_fallback.tf_tensor module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0"], tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 679 : i32}} { 
func.func @__inference_pruned_131() -> tensor<*xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "variable", outputs = "identity_retval_RetVal"}} { %0 = tf_executor.graph { diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/whileop.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/whileop.mlir index 5858a015061596..3a1e6b1e8cbc97 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/whileop.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/whileop.mlir @@ -19,7 +19,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: func @while_test_remove_unused_results // CHECK: [[pred:%.*]] = tfrt_fallback_async.predicate // CHECK-NEXT: tfrt.while [[pred]] @"[[while_func_prefix:.*]]/tfrt_body_1" - // CHECK-SAME: (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) + // CHECK-SAME: (!tfrt.chain, !tfrt_fallback.tf_tensor) -> (!tfrt.chain, !tfrt_fallback.tf_tensor) // CHECK-NOT: func.call func.func @while_test_remove_unused_results(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { %0:2 = "tf.While"(%arg0, %arg1) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor, tensor) -> (tensor, tensor) diff --git a/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir index 3556f932ac1f2f..4ac2e850ae642f 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -split-input-file -tf-executor-to-tfrt-pipeline="target-gpu=true func-use-fallback-tensor=true" -tfrt-lower-tf-savedmodel=hoist-invariant-ops=true %s | FileCheck %s --dump-input=fail --dump-input-filter=all +// RUN: tf-tfrt-opt -split-input-file -tf-executor-to-tfrt-pipeline="target-gpu=true" -tfrt-lower-tf-savedmodel=hoist-invariant-ops=true %s | FileCheck %s --dump-input=fail --dump-input-filter=all func.func private @xla_func_0(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._XlaMustCompile = true, tf._noinline = true, tf._original_func_name = "should_not_be_used"} { %1 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/tfrt/tests/xla_launch_lowering.mlir b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_lowering.mlir index 934166aae198fc..c53b025468d950 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/xla_launch_lowering.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_lowering.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -split-input-file -tf-executor-to-tfrt-pipeline="target-gpu=true use-gpu-compile-and-execute-op=true func-use-fallback-tensor=true" -tfrt-lower-tf-savedmodel=hoist-invariant-ops=true %s | FileCheck %s --dump-input=fail --dump-input-filter=all +// RUN: tf-tfrt-opt -split-input-file -tf-executor-to-tfrt-pipeline="target-gpu=true use-gpu-compile-and-execute-op=true" -tfrt-lower-tf-savedmodel=hoist-invariant-ops=true %s | FileCheck %s --dump-input=fail --dump-input-filter=all func.func private @xla_func_0(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._XlaMustCompile = true, tf._noinline = true, tf._original_func_name = "should_not_be_used"} { %1 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.cc 
b/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.cc index efedf36452dc12..a0ae8cb06e45df 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" @@ -37,39 +38,39 @@ mlir::TypeAttr ConvertTypeAttribute(mlir::TypeAttr type_attr, if (IsSupportedTfrtNumericDType(type)) return type_attr; // For TF custom types, we convert it to custom corert types. - if (type.isa()) + if (mlir::isa(type)) return mlir::TypeAttr::get( tfrt::corert::StringType::get(builder.getContext())); - if (type.isa()) + if (mlir::isa(type)) return mlir::TypeAttr::get( tfrt::corert::ResourceType::get(builder.getContext())); - if (type.isa()) + if (mlir::isa(type)) return mlir::TypeAttr::get( tfrt::corert::VariantType::get(builder.getContext())); - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Quint8Type::get(builder.getContext())); } - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Quint16Type::get(builder.getContext())); } - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Qint8Type::get(builder.getContext())); } - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Qint16Type::get(builder.getContext())); } - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Qint32Type::get(builder.getContext())); } @@ -86,14 +87,15 @@ mlir::Attribute ConvertAttribute(mlir::Attribute attr, mlir::Builder& builder) { // attributes are not supported yet. // Return directly if the attribute is already supported. - if (attr.isa()) + if (mlir::isa(attr)) return attr; // For type attributes, we convert non-standard MLIR types to corresponding // corert types. - if (auto type_attr = attr.dyn_cast()) { - if (auto shape_type = type_attr.getValue().dyn_cast()) { + if (auto type_attr = mlir::dyn_cast(attr)) { + if (auto shape_type = + mlir::dyn_cast(type_attr.getValue())) { if (!shape_type.hasRank()) return tfrt::corert::ShapeAttr::get(builder.getContext()); @@ -106,7 +108,7 @@ mlir::Attribute ConvertAttribute(mlir::Attribute attr, mlir::Builder& builder) { // Convert the attribute to the corresponding format in TFRT dialect if // needed. - if (auto shape_attr = attr.dyn_cast()) { + if (auto shape_attr = mlir::dyn_cast(attr)) { if (!shape_attr.hasRank()) return tfrt::corert::ShapeAttr::get(builder.getContext()); return tfrt::corert::ShapeAttr::get(builder.getContext(), @@ -114,7 +116,7 @@ mlir::Attribute ConvertAttribute(mlir::Attribute attr, mlir::Builder& builder) { } // For arrays, we recursively convert the elements. 
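The attr_lowering_utils.cc hunks above, and most of the C++ hunks that follow, apply one mechanical migration: the deprecated member-function casts on mlir::Type and mlir::Attribute are replaced by the free functions re-exported from mlir/Support/LLVM.h, which is why that include is added to each file. A minimal self-contained sketch of the idiom, using builtin attributes rather than the TF custom types above (the helper names are illustrative, not from this patch):

#include <optional>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"

// The free-function casts mirror llvm::isa/dyn_cast/cast and work uniformly
// on MLIR attributes, types, and values; attr.isa<T>() et al. are deprecated.
bool IsStringAttr(mlir::Attribute attr) {
  return mlir::isa<mlir::StringAttr>(attr);
}

std::optional<llvm::StringRef> GetStringValue(mlir::Attribute attr) {
  // dyn_cast yields a null attribute handle on mismatch, so it can be
  // tested directly in the condition.
  if (auto str_attr = mlir::dyn_cast<mlir::StringAttr>(attr))
    return str_attr.getValue();
  return std::nullopt;
}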
- if (auto array_attr = attr.dyn_cast()) { + if (auto array_attr = mlir::dyn_cast(attr)) { llvm::SmallVector attrs; attrs.reserve(array_attr.size()); for (auto attr : array_attr) { @@ -140,7 +142,7 @@ bool IsSupportedTfrtNumericDType(mlir::Type type) { type.isUnsignedInteger(64)) return true; - if (auto complex_type = type.dyn_cast()) { + if (auto complex_type = mlir::dyn_cast(type)) { auto element_type = complex_type.getElementType(); if (element_type.isF32() || element_type.isF64()) return true; } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc index 48d9f755c16c7b..910f7a83a9f7af 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -46,7 +47,7 @@ CoreRTConverter::CoreRTConverter( addConversion([](tfrt::corert::TensorHandleType type) { return type; }); addConversion([=](mlir::TensorType type) -> std::optional { // Ref types are not supported in both compiler and runtime. - if (type.getElementType().isa()) + if (mlir::isa(type.getElementType())) return std::nullopt; return tensor_handle_type(); }); @@ -74,8 +75,8 @@ mlir::ArrayAttr CoreRTConverter::CreateOpFuncAttrs( auto attr_key = key_and_value.getName(); auto attr_value = key_and_value.getValue(); if (!IsUnusedTfrtAttribute(attr_key) && - attr_value.isa()) { - auto func_attr = attr_value.dyn_cast(); + mlir::isa(attr_value)) { + auto func_attr = mlir::dyn_cast(attr_value); auto converted = CanonicalizeTensorflowFunctionName( symbol_table, func_attr.getValue().str(), use_mlir_func_name); if (!converted) return {}; @@ -126,7 +127,7 @@ std::optional CoreRTConverter::ParseDeviceName( } auto parsed_device_name = - ParseDeviceName(device_attr.cast().getValue()); + ParseDeviceName(mlir::cast(device_attr).getValue()); if (!parsed_device_name) op->emitWarning("failed to parse device name."); return parsed_device_name; } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc b/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc index 2b1e29c5347096..5f539b8c520e65 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/core/util/device_name_utils.h" #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime @@ -81,8 +82,8 @@ static std::string GetDevice(Operation *op) { SmallVector, 4> attrs; execute_op.getOpAttrs(&attrs); for (std::pair entry : attrs) { - if (entry.first == kDeviceAttr && entry.second.isa()) { - device = entry.second.cast().getValue().str(); + if (entry.first == kDeviceAttr && mlir::isa(entry.second)) { + device = mlir::cast(entry.second).getValue().str(); break; } } @@ -94,7 +95,7 @@ static std::string GetDevice(Operation *op) { // Return the device of the given value. static std::string GetDevice(mlir::Value value, func::FuncOp parent_func_op) { std::string device = ""; - if (BlockArgument block_arg = value.dyn_cast()) { + if (BlockArgument block_arg = mlir::dyn_cast(value)) { if (StringAttr device_attr = parent_func_op.getArgAttrOfType( block_arg.getArgNumber(), kTFRTDeviceAttr)) { device = device_attr.getValue().str(); @@ -140,10 +141,10 @@ void CrossDeviceTransferPass::runOnOperation() { for (mlir::Value arg : op->getOperands()) { // Do not transfer non-TensorHandle values. - if (!arg.getType().isa()) continue; + if (!mlir::isa(arg.getType())) continue; // Do not transfer the result of corert.transfer op. - if (OpResult op_result = arg.dyn_cast()) { + if (OpResult op_result = mlir::dyn_cast(arg)) { Operation *defining_op = arg.getDefiningOp(); if (llvm::isa(defining_op)) continue; } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.cc b/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.cc index ef8c2ec38ce64b..77759a631f177f 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" @@ -31,7 +32,7 @@ FallbackConverter::FallbackConverter(mlir::MLIRContext *context) addConversion([](tfrt::fallback::TFTensorType type) { return type; }); addConversion([=](mlir::TensorType type) -> std::optional { // Ref types are not supported in both compiler and runtime. 
- if (type.getElementType().isa()) { + if (mlir::isa(type.getElementType())) { return std::nullopt; } @@ -46,9 +47,9 @@ FallbackConverter::FallbackConverter(mlir::MLIRContext *context) mlir::Value ConvertCoreRTTensorHandleToFallbackTensor( mlir::Location loc, llvm::StringRef device, mlir::Value value, mlir::ConversionPatternRewriter &rewriter) { - if (value.getType().isa()) return value; + if (mlir::isa(value.getType())) return value; - if (!value.getType().isa()) return {}; + if (!mlir::isa(value.getType())) return {}; mlir::OpBuilder::InsertionGuard guard(rewriter); @@ -82,9 +83,9 @@ mlir::Value ConvertCoreRTTensorHandleToFallbackTensor( mlir::Value ConvertFallbackTensorToCoreRTTensorHandle( mlir::Location loc, mlir::Value value, mlir::ConversionPatternRewriter &rewriter) { - if (value.getType().isa()) return value; + if (mlir::isa(value.getType())) return value; - if (!value.getType().isa()) return {}; + if (!mlir::isa(value.getType())) return {}; // Use CPU device by default if no device is specified. llvm::StringRef device = GetDefaultCpuDeviceName(); @@ -134,7 +135,7 @@ mlir::LogicalResult ConvertFallbackOperands( llvm::SmallVectorImpl *new_operands, mlir::ConversionPatternRewriter &rewriter) { for (auto operand : operands) { - if (!operand.getType().isa()) { + if (!mlir::isa(operand.getType())) { auto new_operand = ConvertCoreRTTensorHandleToFallbackTensor( op->getLoc(), device, operand, rewriter); if (!new_operand) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 305195e744932f..b06b37c9e6cb27 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -15,6 +15,7 @@ package_group( "//tensorflow/core/tfrt/saved_model/tests/...", ] + if_google([ "//learning/brain/tfrt/cpp_tests/...", + "//learning/serving/servables/tfrt/...", "//learning/pathways/serving/runtime/...", "//learning/pathways/serving/tests/...", "//learning/brain/tfrt/mlir/mlrt/application/pathways/compiler/...", @@ -124,15 +125,9 @@ cc_library( ":ifrt_constants", ":ifrt_types", "//tensorflow/compiler/jit:xla_cpu_jit", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_a_m_inc_gen", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_n_z_inc_gen", - "//tensorflow/compiler/mlir/tensorflow:visitor", "//tensorflow/compiler/mlir/tf2xla/api/v2:legalize_tf", - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_base", @@ -199,6 +194,7 @@ tf_cc_test( ], tags = ["no_oss"], deps = [ + ":ifrt_types", ":tf2hlo", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_helpers", @@ -265,6 +261,7 @@ tf_cc_test( "//tensorflow/core/platform:resource_loader", "//tensorflow/core/tfrt/graph_executor:graph_execution_options", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector", "//tensorflow/core/tfrt/runtime", "//tensorflow/core/tfrt/saved_model:saved_model_testutil", "@com_google_absl//absl/strings", @@ -272,6 +269,7 @@ tf_cc_test( "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", + 
"@local_tsl//tsl/framework/test_util:mock_serving_device_selector", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/python/ifrt", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc index ebaf2570bba3f4..93e41027ec594e 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc @@ -88,13 +88,18 @@ CompileAndRegisterIfrtPrograms(absl::string_view model_name, } }); - auto executable = std::make_unique( - model_name, entry_function_name.str(), *std::move(submodule), - ifrt_model_context.GetClient(), &ifrt_model_context.GetThreadPool(), - &ifrt_model_context.GetLoadedVariableRegistry(), - &ifrt_model_context.GetRestoreTensorRegistry(), - ifrt_model_context.GetDeviceMgr(), - ifrt_model_context.GetShapeRepresentationFn()); + TF_ASSIGN_OR_RETURN( + auto executable, + IfrtServingExecutable::Create( + program_id, model_name, entry_function_name.str(), + *std::move(submodule), ifrt_model_context.GetClient(), + &ifrt_model_context.GetThreadPool(), + &ifrt_model_context.GetLoadedVariableRegistry(), + &ifrt_model_context.GetRestoreTensorRegistry(), + ifrt_model_context.checkpoint_loader_queue(), + ifrt_model_context.GetDeviceMgr(), + ifrt_model_context.GetShapeRepresentationFn(), + ifrt_model_context.GetIfrtServingCoreSelector())); // Register the Ifrt program to `ServingExecutableRegistry` so that // the client TF program can invoke them via `IfrtCall` op. diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h index 2407fe7cc3546c..085c70812feaed 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h @@ -31,6 +31,12 @@ class IfrtBackendCompiler : public tensorflow::BackendCompiler { explicit IfrtBackendCompiler(TpuCompiler* tpu_compiler = nullptr) : tpu_compiler_(tpu_compiler) {} + void GetDependentDialects(mlir::DialectRegistry& registry) const override { + if (tpu_compiler_) { + tpu_compiler_->RegisterTPUDialects(®istry); + } + } + // Rewrites the tensorflow graph in MLIR for IFRT serving. The methods // extracts regions for IFRT execution on accelerator (e.g. TPU). absl::Status CompileTensorflow( diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc index 71ba7724de922b..dea849f2a1e3fa 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc @@ -34,8 +34,10 @@ limitations under the License. 
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/runtime/runtime.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" +#include "tsl/framework/test_util/mock_serving_device_selector.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" @@ -85,8 +87,11 @@ TEST(IfrtBackendCompilerTest, Basic) { tensorflow::tfrt_stub::ModelRuntimeContext runtime_context( &graph_execution_options, /*export_dir=*/"", &resource_context); + tsl::test_util::MockServingDeviceSelector mock_serving_device_selector; + IfrtServingCoreSelector core_selector(&mock_serving_device_selector); + runtime_context.resource_context().CreateResource( - "IfrtModelContext", client, &GetThreadPool()); + "IfrtModelContext", client, &core_selector, &GetThreadPool()); IfrtBackendCompiler compiler; TF_ASSERT_OK(compiler.CompileTensorflow(runtime_context, mlir_module.get())); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc index b3bf510003e797..49a7e817ed8f60 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc @@ -195,6 +195,9 @@ class SinkVariableAsNamedArrayPass // ReadVariableOp. module.walk([&](mlir::TF::ReadVariableOp read_variable_op) { if (!read_variable_op->use_empty()) { + // This variable tensor is used by CPU host. + read_to_load[read_variable_op].setUsedByHost(true); + // Replace CPU use of ReadVariableOp read_variable_op.replaceAllUsesWith( read_to_load[read_variable_op].getTensorFuture()); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/ifrt_cluster.mlir b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/ifrt_cluster.mlir index 1c3ab3703e2384..f93c2532e23d18 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/ifrt_cluster.mlir +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/ifrt_cluster.mlir @@ -1,9 +1,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func.func @main() { - "tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () + "tf_device.cluster_func"() {_replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () func.return } func.func @empty_func() { func.return } -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc index a0b01ba1ffc3f7..0210eb3ed2dc62 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc @@ -66,31 +66,11 @@ namespace tensorflow { namespace ifrt_serving { namespace { static 
constexpr absl::string_view kEntryFuncName = "main";
+}  // namespace

-absl::StatusOr<tensorflow::tpu::TPUCompileMetadataProto> GetCompileMetadata(
-    mlir::func::FuncOp op, absl::Span<const DtypeAndShape> inputs,
-    const xla::ifrt::Client& ifrt_client) {
-  tensorflow::tpu::TPUCompileMetadataProto metadata;
-
-  auto metadata_text_attr =
-      op->getAttrOfType<mlir::StringAttr>(kMetadataTextAttrName);
-
-  if (metadata_text_attr && !metadata_text_attr.getValue().empty()) {
-    // Try __tpu_compile_metadata_text attribute. This only for debugging
-    // purpose.
-    VLOG(1) << "Parsing from attribute " << kMetadataTextAttrName
-            << metadata_text_attr.getValue().str();
-    if (!tsl::protobuf::TextFormat::ParseFromString(
-            metadata_text_attr.getValue().str(), &metadata)) {
-      return absl::InvalidArgumentError(absl::StrCat(
-          "Attribute ", kMetadataTextAttrName, ":",
-          metadata_text_attr.getValue().str(), " cannot be parsed"));
-    }
-  } else {
-    return absl::InvalidArgumentError(
-        absl::StrCat("Missing ", kMetadataTextAttrName));
-  }
-
+absl::Status UpdateCompileMetadata(
+    tensorflow::tpu::TPUCompileMetadataProto& metadata,
+    absl::Span<const DtypeAndShape> inputs) {
   VLOG(3) << "TpuCompileMetadata before shape is populated " << metadata;
   if (metadata.num_replicas() < 1 || metadata.num_cores_per_replica() < 1) {
     return absl::InternalError(
@@ -98,11 +78,6 @@ absl::StatusOr<tensorflow::tpu::TPUCompileMetadataProto> GetCompileMetadata(
         " and number of cores per replica ", metadata.num_cores_per_replica(),
         " must be >= 1"));
   }
-  if (op.getNumResults() != metadata.retvals_size()) {
-    return absl::InternalError(
-        absl::StrCat("Number of retvals mismatched! Expected ",
-                     op.getNumResults(), " got ", metadata.retvals_size()));
-  }
   if (metadata.args_size() != inputs.size()) {
     return absl::InternalError(
         absl::StrCat("Number of inputs mismatched! Expected ",
@@ -125,10 +100,39 @@ absl::StatusOr<tensorflow::tpu::TPUCompileMetadataProto> GetCompileMetadata(
     // Update shape.
     *metadata.mutable_args(i)->mutable_shape() = inputs[i].shape.AsProto();
   }
+  return absl::OkStatus();
+}
+
+absl::StatusOr<tensorflow::tpu::TPUCompileMetadataProto> GetCompileMetadata(
+    mlir::ModuleOp module, const xla::ifrt::Client& ifrt_client) {
+  tensorflow::tpu::TPUCompileMetadataProto metadata;
+
+  auto op = module.lookupSymbol<mlir::func::FuncOp>(kEntryFuncName);
+  if (!op) {
+    return absl::InternalError("Could not find entry function in MLIR Module.");
+  }
+
+  auto metadata_text_attr =
+      op->getAttrOfType<mlir::StringAttr>(kMetadataTextAttrName);
+
+  if (metadata_text_attr && !metadata_text_attr.getValue().empty()) {
+    // Try the __tpu_compile_metadata_text attribute. This is only for
+    // debugging purposes.
+    VLOG(1) << "Parsing from attribute " << kMetadataTextAttrName
+            << metadata_text_attr.getValue().str();
+    if (!tsl::protobuf::TextFormat::ParseFromString(
+            metadata_text_attr.getValue().str(), &metadata)) {
+      return absl::InvalidArgumentError(absl::StrCat(
+          "Attribute ", kMetadataTextAttrName, ":",
+          metadata_text_attr.getValue().str(), " cannot be parsed"));
+    }
+  } else {
+    return absl::InvalidArgumentError(
+        absl::StrCat("Missing ", kMetadataTextAttrName));
+  }

   // Create a default device assignment if one is not given by the model.
   if (!metadata.has_device_assignment()) {
-    // TODO(b/316068010): integrate core selection.
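// With this split, callers first fetch the metadata from the module, then
// bind the input shapes, then compile; the updated tests further below
// exercise exactly this sequence:
//   TF_ASSIGN_OR_RETURN(auto metadata, GetCompileMetadata(module, client));
//   TF_RETURN_IF_ERROR(UpdateCompileMetadata(metadata, dtype_and_shapes));
//   auto result = CompileTfToHlo(module, dtype_and_shapes, "main", client,
//                                metadata, shape_representation_fn);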
TF_ASSIGN_OR_RETURN( auto device_assignment, ifrt_client.GetDefaultDeviceAssignment( @@ -142,11 +146,11 @@ absl::StatusOr GetCompileMetadata( return metadata; } -} // namespace absl::StatusOr CompileTfToHlo( mlir::ModuleOp module, absl::Span inputs, absl::string_view entry_function_name, const xla::ifrt::Client& ifrt_client, + const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) { if (VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile("ifrt_before_bridge_phase2", module); @@ -165,21 +169,6 @@ absl::StatusOr CompileTfToHlo( TF_ASSIGN_OR_RETURN( auto* client, xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform)); - auto entry_fn = module.lookupSymbol(kEntryFuncName); - if (!entry_fn) { - return absl::InternalError("Could not find entry function in MLIR Module."); - } - - if (inputs.size() != entry_fn.getNumArguments()) { - return absl::InternalError( - absl::StrCat("Entry function arguments mismatched! Expected ", - entry_fn.getNumArguments(), " got", inputs.size())); - } - - TF_ASSIGN_OR_RETURN(tensorflow::tpu::TPUCompileMetadataProto compile_metadata, - GetCompileMetadata(entry_fn, inputs, ifrt_client)); - - VLOG(1) << "Compilation metadata: " << compile_metadata; std::vector arg_shapes; for (const auto& input : inputs) { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h index fec9bbb2c740e7..48d7cabdd14286 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h @@ -24,8 +24,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/client.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" @@ -38,11 +36,19 @@ struct Tf2HloResult { tf2xla::HostComputeMetadata host_compute_metadata; }; +absl::Status UpdateCompileMetadata( + tensorflow::tpu::TPUCompileMetadataProto& metadata, + absl::Span inputs); + +absl::StatusOr GetCompileMetadata( + mlir::ModuleOp module, const xla::ifrt::Client& ifrt_client); + // A class that convert tf module to hlo // TODO(b/304839793): provide wrap persistent compilation cache. absl::StatusOr CompileTfToHlo( mlir::ModuleOp module, absl::Span inputs, absl::string_view entry_function_name, const xla::ifrt::Client& ifrt_client, + const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn); } // namespace ifrt_serving diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index 7ee1c450426b20..b201370ea3ae7b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -32,6 +33,7 @@ limitations under the License. 
#include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" @@ -93,8 +95,14 @@ TEST(Tf2HloTest, Empty) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, xla::ifrt::test_util::GetClient()); - auto result = CompileTfToHlo(mlir_module.get(), {}, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + TF_ASSERT_OK_AND_ASSIGN( + tensorflow::tpu::TPUCompileMetadataProto compile_metadata, + GetCompileMetadata(mlir_module.get(), *client)); + TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, {})); + + auto result = + CompileTfToHlo(mlir_module.get(), {}, "main", *client, compile_metadata, + tensorflow::IdentityShapeRepresentationFn()); TF_ASSERT_OK(result.status()); } @@ -125,9 +133,15 @@ TEST(Tf2HloTest, Tuple) { std::vector dtype_and_shapes; dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {1, 3}}); dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {3, 1}}); - auto result = - CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + + TF_ASSERT_OK_AND_ASSIGN( + tensorflow::tpu::TPUCompileMetadataProto compile_metadata, + GetCompileMetadata(mlir_module.get(), *client)); + TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); + + auto result = CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", + *client, compile_metadata, + tensorflow::IdentityShapeRepresentationFn()); TF_ASSERT_OK(result.status()); } @@ -157,9 +171,15 @@ TEST(Tf2HloTest, Spmd) { std::vector dtype_and_shapes; dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {4, 64}}); - auto result = - CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + + TF_ASSERT_OK_AND_ASSIGN( + tensorflow::tpu::TPUCompileMetadataProto compile_metadata, + GetCompileMetadata(mlir_module.get(), *client)); + TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); + + auto result = CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", + *client, compile_metadata, + tensorflow::IdentityShapeRepresentationFn()); LOG(INFO) << result->compile_metadata; TF_ASSERT_OK(result.status()); @@ -227,9 +247,15 @@ TEST(Tf2HloTest, UsingDefaultDeviceAssignment) { dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {4, 64}}); dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {64, 10}}); dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {1, 4}}); - auto result = - CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + + TF_ASSERT_OK_AND_ASSIGN( + tensorflow::tpu::TPUCompileMetadataProto compile_metadata, + GetCompileMetadata(mlir_module.get(), *client)); + TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); + + auto result = CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", + *client, compile_metadata, + tensorflow::IdentityShapeRepresentationFn()); LOG(INFO) << result->compile_metadata; TF_ASSERT_OK(result.status()); @@ -323,9 +349,14 @@ TEST(Tf2HloTest, XlaCallHostCallback) { dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); - auto result = - CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, - 
tensorflow::IdentityShapeRepresentationFn()); + TF_ASSERT_OK_AND_ASSIGN( + tensorflow::tpu::TPUCompileMetadataProto compile_metadata, + GetCompileMetadata(mlir_module.get(), *client)); + TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); + + auto result = CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", + *client, compile_metadata, + tensorflow::IdentityShapeRepresentationFn()); TF_ASSERT_OK(result.status()); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc b/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc index 2b2cbbcf318d15..d6c87abeedd54a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" @@ -69,7 +70,7 @@ class InsertFallbackTensorCopy // Process function arguments first. for (auto arg : func_op.getArguments()) { - if (!arg.getType().isa()) continue; + if (!mlir::isa(arg.getType())) continue; InsertFallbackTensorCopyForValue(arg, func_op->getLoc(), builder, stream_analysis); } @@ -91,7 +92,7 @@ class InsertFallbackTensorCopy // Process each result value. for (auto result : op->getResults()) { - if (!result.getType().isa()) continue; + if (!mlir::isa(result.getType())) continue; InsertFallbackTensorCopyForValue(result, op->getLoc(), builder, stream_analysis); } @@ -147,7 +148,7 @@ class InsertFallbackTensorCopy // For each stream, we will create one new value that replaces the uses in // that stream. - assert(value.getType().isa()); + assert(mlir::isa(value.getType())); // The number of results is the number candidate streams. llvm::SmallVector result_types(copies.size(), diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lower_bound_batch_threads.cc b/tensorflow/compiler/mlir/tfrt/transforms/lower_bound_batch_threads.cc new file mode 100644 index 00000000000000..2c2883181e942c --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/lower_bound_batch_threads.cc @@ -0,0 +1,93 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
+
+namespace tensorflow {
+namespace tfrt_compiler {
+namespace {
+
+class LowerBoundBatchThreadsPass
+    : public mlir::PassWrapper<LowerBoundBatchThreadsPass,
+                               mlir::OperationPass<mlir::ModuleOp>> {
+ public:
+  explicit LowerBoundBatchThreadsPass(uint64_t min_num_batch_threads)
+      : mlir::PassWrapper<LowerBoundBatchThreadsPass,
+                          mlir::OperationPass<mlir::ModuleOp>>() {
+    min_num_batch_threads_ = min_num_batch_threads;
+  }
+  LowerBoundBatchThreadsPass()
+      : mlir::PassWrapper<LowerBoundBatchThreadsPass,
+                          mlir::OperationPass<mlir::ModuleOp>>() {}
+  LowerBoundBatchThreadsPass(const LowerBoundBatchThreadsPass& other)
+      : mlir::PassWrapper<LowerBoundBatchThreadsPass,
+                          mlir::OperationPass<mlir::ModuleOp>>(other) {}
+
+  LowerBoundBatchThreadsPass& operator=(
+      const LowerBoundBatchThreadsPass& other) = delete;
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerBoundBatchThreadsPass)
+
+ private:
+  llvm::StringRef getArgument() const final {
+    return "tfrt-lower-bound-batch-threads";
+  }
+
+  llvm::StringRef getDescription() const final {
+    return "Lower bound batch threads for batch ops.";
+  }
+
+  void runOnOperation() override {
+    if (min_num_batch_threads_ > 0) {
+      mlir::ModuleOp module = getOperation();
+      module.walk([&](mlir::TF::BatchFunctionOp batch_op) {
+        int64_t num_batch_threads = batch_op.getNumBatchThreads();
+        num_batch_threads =
+            std::max(num_batch_threads, min_num_batch_threads_.getValue());
+        batch_op.setNumBatchThreads(num_batch_threads);
+      });
+    }
+  }
+
+ protected:
+  mlir::Pass::Option<int64_t> min_num_batch_threads_{
+      *this, "tfrt-min-num-batch-threads", llvm::cl::init(1),
+      llvm::cl::desc("Minimum number of batch threads")};
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateLowerBoundBatchThreadsPass(int64_t min_num_batch_threads) {
+  return std::make_unique<LowerBoundBatchThreadsPass>(min_num_batch_threads);
+}
+
+static mlir::PassRegistration<LowerBoundBatchThreadsPass> register_pass;
+
+}  // namespace tfrt_compiler
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc
index 17e3d8be95204d..01ae5811b46b9a 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc
@@ -44,6 +44,7 @@ limitations under the License.
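The new pass above is wired into the pre-invariant-optimization pipeline in passes.cc further down in this diff. As a standalone illustration, a hedged sketch of invoking it through the factory declared in passes.h (the driver function here is hypothetical):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"

// Hypothetical driver: raise num_batch_threads on every tf.BatchFunction in
// `module` to at least 4, leaving already-larger values untouched.
mlir::LogicalResult RaiseBatchThreadFloor(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(tensorflow::tfrt_compiler::CreateLowerBoundBatchThreadsPass(4));
  return pm.run(module);
}

The pass is a no-op when the configured minimum is zero, so the default pipeline behavior is unchanged unless options.min_num_batch_threads is set.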
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" @@ -231,7 +232,7 @@ void FindCalleesRecursiveForOp(const mlir::SymbolTable &symbol_table, llvm::StringSet<> &callees) { for (const auto &named_attr : op->getAttrs()) { if (auto symbol_attr = - named_attr.getValue().dyn_cast()) { + mlir::dyn_cast(named_attr.getValue())) { auto symbol = symbol_attr.getValue(); if (!callees.contains(symbol)) { callees.insert(symbol); @@ -337,7 +338,8 @@ class LowerTFSavedModelPass func_op->removeAttr(kTfSavedModelExportedNamesAttr); for (auto exported_name : exported_names) { auto exported_func_op = func_op.clone(); - exported_func_op.setName(exported_name.cast()); + exported_func_op.setName( + mlir::cast(exported_name)); // If it is a session initializer, we want to maximize parallelism // and do not perform any stream merge, to minimize latency. @@ -631,8 +633,8 @@ class ConvertReferenceVariableToResourceVariablePass mlir::LogicalResult ConvertReferenceVariableToResourceVariable( mlir::TF::VariableV2Op var_op) { - auto tensor_type = - mlir::TF::DropRefType(var_op.getRef().getType()).cast(); + auto tensor_type = mlir::cast( + mlir::TF::DropRefType(var_op.getRef().getType())); llvm::SmallVector identity_ops; llvm::SmallVector assign_ops; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD index d7fafb49ee6cdd..ed518285828d1a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD @@ -1,7 +1,6 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", "//tensorflow/compiler/mlir/tfrt:__subpackages__", "//tensorflow/core/tfrt:__subpackages__", ], diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc index af932ff5011895..03817906b9772c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc @@ -59,37 +59,39 @@ absl::StatusOr ConvertTfMlirToBytecode( mlrt::bc::Buffer bytecode_buffer; TF_RETURN_IF_ERROR(ConvertTfMlirToRuntimeExecutable( options, module, - [&bytecode_buffer, &fallback_state, &model_context, module_with_op_keys]( - mlir::PassManager& pm, mlir::ModuleOp module, - const TfrtPipelineOptions& options) { - if (auto* flib_def = model_context.function_library_definition()) { - // Copy the module before exporting as exporting to graph will - // transform the MLIR to TFG dialect. 
- mlir::OwningOpRef copy(module.clone()); - TF_RETURN_IF_ERROR( - ExportFunctionDefs(*copy, [flib_def](FunctionDef function_def) { - VLOG(1) << absl::StrCat( - "Exporting MLIR function as function_def: ", - // clang-tidy off - function_def.DebugString() - // clang-tidy on - ); + [&bytecode_buffer, &fallback_state, &model_context, + backend_compiler = options.backend_compiler, + module_with_op_keys](mlir::PassManager& pm, mlir::ModuleOp module, + const TfrtPipelineOptions& options) { + if (backend_compiler) { + if (auto* flib_def = model_context.function_library_definition()) { + // Copy the module before exporting as exporting to graph will + // transform the MLIR to TFG dialect. + mlir::OwningOpRef copy(module.clone()); + TF_RETURN_IF_ERROR( + ExportFunctionDefs(*copy, [flib_def](FunctionDef function_def) { + VLOG(1) << absl::StrCat( + "Exporting MLIR function as function_def: ", + // NOLINTNEXTLINE + function_def.DebugString()); - // The TF MLIR compiler may change the function name. Then we - // need to retrieve the original name from the - // _original_func_name attribute. - auto iter = function_def.attr().find("_original_func_name"); - if (iter != function_def.attr().end()) { - function_def.mutable_signature()->set_name(iter->second.s()); - } + // The TF MLIR compiler may change the function name. Then we + // need to retrieve the original name from the + // _original_func_name attribute. + auto iter = function_def.attr().find("_original_func_name"); + if (iter != function_def.attr().end()) { + function_def.mutable_signature()->set_name( + iter->second.s()); + } - const auto& name = function_def.signature().name(); - if (flib_def->Contains(name)) { - TF_RETURN_IF_ERROR(flib_def->RemoveFunction(name)); - } + const auto& name = function_def.signature().name(); + if (flib_def->Contains(name)) { + TF_RETURN_IF_ERROR(flib_def->RemoveFunction(name)); + } - return flib_def->AddFunctionDef(function_def); - })); + return flib_def->AddFunctionDef(function_def); + })); + } } mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index 350c424636b2f8..0eba87bb7dfcf1 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h.inc" @@ -349,7 +350,8 @@ class TFIfrtLoadVariableOpConversion auto new_op = rewriter.create( op.getLoc(), result_types, adaptor.getOperands()[0], - op.getDeviceShardingConfigProtoTextAttr(), op.getNameAttr()); + op.getDeviceShardingConfigProtoTextAttr(), op.getNameAttr(), + op.getUsedByHostAttr()); rewriter.replaceOp(op, new_op); return mlir::success(); @@ -380,11 +382,11 @@ class IfrtRestoreVariableOpConversion }; std::optional DecodeLongName(mlir::Location loc) { - if (auto name_loc = loc.dyn_cast()) { + if (auto name_loc = mlir::dyn_cast(loc)) { return name_loc.getName().str(); } - if (auto fused_loc = loc.dyn_cast()) { + if (auto fused_loc = mlir::dyn_cast(loc)) { std::string fused_name; for (auto l : fused_loc.getLocations()) { if (auto n = DecodeLongName(l)) { @@ -462,7 +464,7 @@ class ExecuteOpConversion final : public mlir::ConversionPattern { tensorflow::TensorProto tensor_proto; auto status = ConvertToTensorProto(const_op.getValue(), &tensor_proto); if (!status.ok()) - return const_op.emitError(tsl::NullTerminatedMessage(status)); + return const_op.emitError(absl::StatusMessageAsCStr(status)); rewriter.replaceOpWithNewOp( op, rewriter.getType(), @@ -1027,7 +1029,7 @@ class TfToMlrtConversionPass type_converter_.addConversion( [=](mlir::TensorType type) -> std::optional { // Ref types are not supported in both compiler and runtime. - if (type.getElementType().isa()) + if (mlir::isa(type.getElementType())) return std::nullopt; return tf_mlrt::TFTensorType::get(context); }); @@ -1037,8 +1039,8 @@ class TfToMlrtConversionPass mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value { if (inputs.size() != 1) return mlir::Value(); - if (inputs[0].getType().isa()) { - if (desired_type.isa()) { + if (mlir::isa(inputs[0].getType())) { + if (mlir::isa(desired_type)) { return builder.create(loc, desired_type, inputs[0]); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc index a7975c40e1ff48..0bc2a9617b12d6 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc @@ -286,6 +286,10 @@ class WhileToMapFnPass for (auto tensor_list_index : loop_info.tensor_list_or_flow_in) { mlir::Operation *tensor_list_or_flow_in_defining_op = while_op.getOperand(tensor_list_index).getDefiningOp(); + if (tensor_list_or_flow_in_defining_op == nullptr) { + return mlir::failure(); + } + mlir::Operation *max_iterations = nullptr; if (loop_info.max_iterations_arg_idx.has_value()) { max_iterations = diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc index e13e8f36b1a436..0e47fad312c7cc 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" @@ -41,7 +42,8 @@ class FoldDeviceIndex : public mlir::OpRewritePattern { int32_t i = 0; mlir::ArrayAttr device_names = op.getDeviceNames(); for (; i < device_names.size(); ++i) { - auto device_name = device_names[i].cast().getValue(); + auto device_name = + mlir::cast(device_names[i]).getValue(); if (device_name == parsed_name.type) break; } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index a0d7dbb371e375..69bc9370424671 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -117,6 +117,10 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( // Merge non-side-effecting tf.If ops if their operands are the same. pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass()); + // Lower bound on the number of batch threads in `tf.BatchFunction`. + pm.addPass(tfrt_compiler::CreateLowerBoundBatchThreadsPass( + options.min_num_batch_threads)); + // Deduplicate functions invoked by tf.BatchFunction with the same // shared_name pm.addPass( diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.h b/tensorflow/compiler/mlir/tfrt/transforms/passes.h index 6bcbf8dbad317b..e1c848210dd19b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.h @@ -66,6 +66,10 @@ std::unique_ptr> CreateMergeTfIfOpsPass(); std::unique_ptr> CreateDeduplicateFunctionsInovkedByBatchFunctionPass(); +// Create a pass to lower bound the number of threads in tf.BatchFunction. +std::unique_ptr> +CreateLowerBoundBatchThreadsPass(int64_t min_num_batch_threads); + // Create a pass to fuse the TPU Ops for TFRT. std::unique_ptr> CreateFuseTpuCompileAndExecutePass(); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc index 848498c68ba71c..8bdb39c913bf75 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -82,7 +83,7 @@ llvm::SmallVector FindValueInCallees( llvm::SmallDenseSet callees; for (const auto &named_attr : caller->getAttrs()) { if (auto symbol_attr = - named_attr.getValue().dyn_cast()) { + mlir::dyn_cast(named_attr.getValue())) { auto symbol = symbol_attr.getValue(); auto callee = symbol_table.lookup(symbol); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc index 693bd78df0b170..f090745e0ae1c4 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -181,7 +182,7 @@ class GpuCompileAndExecuteOpConversion if (!xla_function) { return op->emitWarning("failed to find 'function' attribute"); } - auto func_attr = xla_function.dyn_cast(); + auto func_attr = mlir::dyn_cast(xla_function); if (!func_attr || func_attr.getValue().empty()) { return op->emitWarning("failed to find a non-empty 'function' attribute"); } @@ -512,7 +513,7 @@ class FallbackConstOpConversion mlir::ConversionPatternRewriter &rewriter) const override { // Some data types are handled separately using a fast path. if (IsSupportedTfrtNumericDType(op.getDtype()) || - op.getDtype().isa()) + mlir::isa(op.getDtype())) return failure(); // For other data types that do not have a fast path (eg. quantized types), @@ -520,7 +521,7 @@ class FallbackConstOpConversion tensorflow::TensorProto tensor_proto; auto status = ConvertToTensorProto(op.getValue(), &tensor_proto); - if (!status.ok()) return op.emitError(tsl::NullTerminatedMessage(status)); + if (!status.ok()) return op.emitError(absl::StatusMessageAsCStr(status)); rewriter.replaceOpWithNewOp( op, rewriter.getType(), @@ -737,11 +738,11 @@ class FallbackBatchFunctionOpConversion // Lower a tf.Const op that creates a string tensor to a native // corert.create_string_tensor op. 
-class CoreRTConstDenseTensorOpConversion +class FallbackConstDenseTensorOpConversion : public mlir::OpConversionPattern { public: - CoreRTConstDenseTensorOpConversion(mlir::MLIRContext *context, - CoreRTConverter *corert_converter) + FallbackConstDenseTensorOpConversion(mlir::MLIRContext *context, + CoreRTConverter *corert_converter) : mlir::OpConversionPattern(context, kCoreRTBenefit), corert_converter_(*corert_converter) {} @@ -755,9 +756,9 @@ class CoreRTConstDenseTensorOpConversion if (auto parsed_device_name = corert_converter_.ParseDeviceName(op)) if (parsed_device_name->device_type != DEVICE_CPU) return failure(); - auto new_op = rewriter.create( - op.getLoc(), corert_converter_.tensor_handle_type(), - op.getValue().cast()); + auto new_op = rewriter.create( + op.getLoc(), rewriter.getType(), + mlir::cast(op.getValue())); rewriter.replaceOp(op, new_op->getResult(0)); return success(); } @@ -859,21 +860,21 @@ class TFRTFuncOpSignatureConversion // Lower a tf.Const op that creates a string tensor to a native // corert.create_string_tensor op. -class CoreRTConstStringTensorOpConversion +class FallbackConstStringTensorOpConversion : public mlir::OpConversionPattern { public: - CoreRTConstStringTensorOpConversion(mlir::MLIRContext *context, - CoreRTConverter *corert_converter) + FallbackConstStringTensorOpConversion(mlir::MLIRContext *context, + CoreRTConverter *corert_converter) : mlir::OpConversionPattern(context, kCoreRTBenefit), corert_converter_(*corert_converter) {} LogicalResult matchAndRewrite( mlir::TF::ConstOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // NOLINT - if (!op.getDtype().isa()) return failure(); + if (!mlir::isa(op.getDtype())) return failure(); DenseStringElementsAttr attr = - op.getValue().cast(); + mlir::cast(op.getValue()); llvm::SmallVector values; values.reserve(attr.getNumElements()); @@ -889,8 +890,8 @@ class CoreRTConstStringTensorOpConversion for (auto dim : shape) dims.push_back(rewriter.getIntegerAttr(i64_type, dim)); - auto new_op = rewriter.create( - op.getLoc(), corert_converter_.tensor_handle_type(), + auto new_op = rewriter.create( + op.getLoc(), rewriter.getType(), rewriter.getArrayAttr(dims), rewriter.getArrayAttr(values)); rewriter.replaceOp(op, new_op.getResult()); @@ -905,16 +906,11 @@ class CoreRTConstStringTensorOpConversion LogicalResult ConvertFunctionCallOperands( mlir::Operation *op, ValueRange operands, llvm::SmallVectorImpl *new_operands, - mlir::ConversionPatternRewriter &rewriter, bool func_use_fallback_tensor) { - if (func_use_fallback_tensor) { - // TODO(b/182232457): Support other devices. - return tfrt_compiler::ConvertFallbackOperands( - op, tfrt_compiler::GetDefaultCpuDeviceName(), operands, new_operands, - rewriter); - } else { - return tfrt_compiler::ConvertCoreRTOperands(op, operands, new_operands, - rewriter); - } + mlir::ConversionPatternRewriter &rewriter) { + // TODO(b/182232457): Support other devices. + return tfrt_compiler::ConvertFallbackOperands( + op, tfrt_compiler::GetDefaultCpuDeviceName(), operands, new_operands, + rewriter); } // Convert TF call ops (eg. StatefulPartitionedCall) to tfrt.call. 
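With func_use_fallback_tensor gone, the helper above is the single operand path for every call-like lowering in this file. A minimal sketch of a caller (the wrapper name is hypothetical; the helper is the one defined above):

// Hypothetical wrapper: materialize all operands of a call-like op as
// !tfrt_fallback.tf_tensor values (pinned to the default CPU device inside
// the helper) before the op itself is rewritten.
mlir::LogicalResult LowerCallOperandsExample(
    mlir::Operation *op, mlir::ValueRange operands,
    llvm::SmallVectorImpl<mlir::Value> &converted,
    mlir::ConversionPatternRewriter &rewriter) {
  return ConvertFunctionCallOperands(op, operands, &converted, rewriter);
}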
@@ -923,12 +919,10 @@ class TFRTCallOpConversion : public mlir::OpConversionPattern { public: TFRTCallOpConversion(mlir::MLIRContext *context, mlir::TypeConverter *type_converter, - CoreRTConverter *corert_converter, - bool func_use_fallback_tensor) + CoreRTConverter *corert_converter) : mlir::OpConversionPattern(context), type_converter_(*type_converter), - corert_converter_(*corert_converter), - func_use_fallback_tensor_(func_use_fallback_tensor) {} + corert_converter_(*corert_converter) {} LogicalResult matchAndRewrite( CallOp op, typename CallOp::Adaptor adaptor, @@ -953,8 +947,7 @@ class TFRTCallOpConversion : public mlir::OpConversionPattern { // operand is !tfrt_fallback.tf_tensor, and it is also used as fallback // tensor inside the callee function. if (mlir::failed(ConvertFunctionCallOperands(op, adaptor.getOperands(), - &new_operands, rewriter, - func_use_fallback_tensor_))) + &new_operands, rewriter))) return failure(); llvm::SmallVector result_types; @@ -982,7 +975,6 @@ class TFRTCallOpConversion : public mlir::OpConversionPattern { private: mlir::TypeConverter &type_converter_; CoreRTConverter &corert_converter_; - bool func_use_fallback_tensor_; }; // Convert func ReturnOp to tfrt.return. @@ -993,11 +985,9 @@ class TFRTReturnOpConversion : public mlir::OpConversionPattern { public: TFRTReturnOpConversion(mlir::MLIRContext *context, - CoreRTConverter *corert_converter, - bool func_use_fallback_tensor) + CoreRTConverter *corert_converter) : mlir::OpConversionPattern(context), - corert_converter_(*corert_converter), - func_use_fallback_tensor_(func_use_fallback_tensor) {} + corert_converter_(*corert_converter) {} LogicalResult matchAndRewrite( mlir::func::ReturnOp op, OpAdaptor adaptor, @@ -1013,8 +1003,7 @@ class TFRTReturnOpConversion new_operands.push_back( corert_converter_.GetLocalSideEffectChain(op, &rewriter)); if (mlir::failed(ConvertFunctionCallOperands(op, adaptor.getOperands(), - &new_operands, rewriter, - func_use_fallback_tensor_))) + &new_operands, rewriter))) return failure(); rewriter.replaceOpWithNewOp(op, new_operands); @@ -1023,7 +1012,6 @@ class TFRTReturnOpConversion private: CoreRTConverter &corert_converter_; - bool func_use_fallback_tensor_; }; // Convert tf.Case op to tfrt.Case. @@ -1038,12 +1026,10 @@ class TFRTCaseOpConversion : public mlir::OpConversionPattern { public: TFRTCaseOpConversion(mlir::MLIRContext *context, mlir::TypeConverter *type_converter, - CoreRTConverter *corert_converter, - bool func_use_fallback_tensor) + CoreRTConverter *corert_converter) : mlir::OpConversionPattern(context), type_converter_(*type_converter), - corert_converter_(*corert_converter), - func_use_fallback_tensor_(func_use_fallback_tensor) {} + corert_converter_(*corert_converter) {} LogicalResult matchAndRewrite( TF::CaseOp op, OpAdaptor adaptor, @@ -1060,14 +1046,14 @@ class TFRTCaseOpConversion : public mlir::OpConversionPattern { llvm::SmallVector branch_operands; branch_operands.push_back( corert_converter_.GetLocalSideEffectChain(op, &rewriter)); - if (mlir::failed(ConvertFunctionCallOperands( - op, adaptor.getOperands().drop_front(), &branch_operands, rewriter, - func_use_fallback_tensor_))) + if (mlir::failed( + ConvertFunctionCallOperands(op, adaptor.getOperands().drop_front(), + &branch_operands, rewriter))) return failure(); mlir::Value index_operand = adaptor.getOperands()[0]; // TODO(b/182233401): Support TF tensor; remove the conversion op here. 
- if (index_operand.getType().isa()) { + if (mlir::isa(index_operand.getType())) { // TODO(b/182232457): Support other devices. index_operand = rewriter @@ -1079,7 +1065,7 @@ class TFRTCaseOpConversion : public mlir::OpConversionPattern { tfrt_compiler::GetDefaultCpuDeviceName()) .getResult(0); } - if (!index_operand.getType().isa()) + if (!mlir::isa(index_operand.getType())) return op.emitError( "branch index operand is expected to be a TensorHandle."); mlir::Value index_value = @@ -1096,12 +1082,11 @@ class TFRTCaseOpConversion : public mlir::OpConversionPattern { private: mlir::TypeConverter &type_converter_; CoreRTConverter &corert_converter_; - bool func_use_fallback_tensor_; }; static mlir::Value GetPredicate(mlir::Operation *op, mlir::Value cond_operand, mlir::ConversionPatternRewriter &rewriter) { - if (!cond_operand.getType().isa()) { + if (!mlir::isa(cond_operand.getType())) { cond_operand = tfrt_compiler::ConvertCoreRTTensorHandleToFallbackTensor( op->getLoc(), tfrt_compiler::GetDefaultCpuDeviceName(), cond_operand, rewriter); @@ -1119,12 +1104,10 @@ class TFRTCondOpConversion : public mlir::OpConversionPattern { public: TFRTCondOpConversion(mlir::MLIRContext *context, mlir::TypeConverter *type_converter, - CoreRTConverter *corert_converter, - bool func_use_fallback_tensor) + CoreRTConverter *corert_converter) : mlir::OpConversionPattern(context), type_converter_(*type_converter), - corert_converter_(*corert_converter), - func_use_fallback_tensor_(func_use_fallback_tensor) {} + corert_converter_(*corert_converter) {} mlir::LogicalResult matchAndRewrite( mlir::TF::IfOp op, OpAdaptor adaptor, @@ -1150,8 +1133,7 @@ class TFRTCondOpConversion : public mlir::OpConversionPattern { corert_converter_.GetLocalSideEffectChain(op, &rewriter)); if (mlir::failed(ConvertFunctionCallOperands( - op, adaptor.getOperands().drop_front(), &new_operands, rewriter, - func_use_fallback_tensor_))) + op, adaptor.getOperands().drop_front(), &new_operands, rewriter))) return failure(); auto new_op = rewriter.create( @@ -1174,7 +1156,6 @@ class TFRTCondOpConversion : public mlir::OpConversionPattern { private: mlir::TypeConverter &type_converter_; CoreRTConverter &corert_converter_; - bool func_use_fallback_tensor_; }; // Convert TF WhileOp to tfrt.while. tfrt.while use a boolean condition and has @@ -1219,14 +1200,12 @@ class TFRTWhileOpConversion mlir::SymbolTable *symbol_table, const tfrt_compiler::TensorArraySideEffectAnalysis *tensor_array_side_effect_analysis, - bool func_use_fallback_tensor, bool enable_while_parallel_iterations) : mlir::OpConversionPattern(context), type_converter_(*type_converter), corert_converter_(*corert_converter), symbol_table_(*symbol_table), tensor_array_side_effect_analysis_(*tensor_array_side_effect_analysis), - func_use_fallback_tensor_(func_use_fallback_tensor), enable_while_parallel_iterations_(enable_while_parallel_iterations) {} mlir::LogicalResult matchAndRewrite( @@ -1248,8 +1227,7 @@ class TFRTWhileOpConversion // specified in the option. 
llvm::SmallVector new_operands; if (mlir::failed(ConvertFunctionCallOperands(op, adaptor.getOperands(), - &new_operands, rewriter, - func_use_fallback_tensor_))) + &new_operands, rewriter))) return failure(); // Create the predicate function that calls the original cond function and @@ -1328,7 +1306,6 @@ class TFRTWhileOpConversion mlir::SymbolTable &symbol_table_; const tfrt_compiler::TensorArraySideEffectAnalysis &tensor_array_side_effect_analysis_; - bool func_use_fallback_tensor_; bool enable_while_parallel_iterations_; }; @@ -1518,9 +1495,8 @@ void PopulateTFToTFRTConversionPatterns( const tfrt_compiler::CostAnalysis *cost_analysis, const tfrt_compiler::TensorArraySideEffectAnalysis *tensor_array_side_effect_analysis, - bool func_use_fallback_tensor, bool enable_while_parallel_iterations, - bool tpu_lower_to_fallback, bool target_tpurt, - bool use_gpu_compile_and_execute_op) { + bool enable_while_parallel_iterations, bool tpu_lower_to_fallback, + bool target_tpurt, bool use_gpu_compile_and_execute_op) { // By default, we lower all TF ops to fallback ops. patterns->add( context, corert_converter, fallback_converter, symbol_table, @@ -1534,23 +1510,17 @@ void PopulateTFToTFRTConversionPatterns( // For control flow ops, we handle them according to the option. mlir::TypeConverter *func_type_converter; - if (func_use_fallback_tensor) { - func_type_converter = fallback_converter; - } else { - func_type_converter = corert_converter; - } + func_type_converter = fallback_converter; patterns->add(context, func_type_converter); - patterns->add(context, corert_converter, - func_use_fallback_tensor); + patterns->add(context, corert_converter); patterns->add( context, func_type_converter, corert_converter, symbol_table, - tensor_array_side_effect_analysis, func_use_fallback_tensor, - enable_while_parallel_iterations); + tensor_array_side_effect_analysis, enable_while_parallel_iterations); patterns->add, TFRTCallOpConversion, TFRTCallOpConversion, TFRTCaseOpConversion, TFRTCondOpConversion>( - context, func_type_converter, corert_converter, func_use_fallback_tensor); + context, func_type_converter, corert_converter); // For tf.BatchFunction, we need a special fallback op to batch a BEF // function. @@ -1562,8 +1532,9 @@ void PopulateTFToTFRTConversionPatterns( // Here we use specialized patterns for tf.Const on CPU as it is incorrect to // use ExecuteOp pattern to convert string tensor attribute. - patterns->add(context, corert_converter); + patterns->add(context, + corert_converter); } // Lower TF dialect MLIR to TFRT dialect. 
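For reference, the pattern registrations in the hunk above with their stripped template lists reconstructed; the concrete CallOp variants (StatefulPartitionedCallOp, PartitionedCallOp, LegacyCallOp) are a best-effort assumption from this file's other hunks, and the fragment assumes the surrounding PopulateTFToTFRTConversionPatterns scope:

patterns->add<TFRTFuncOpSignatureConversion>(context, func_type_converter);
patterns->add<TFRTReturnOpConversion>(context, corert_converter);
patterns->add<TFRTWhileOpConversion>(
    context, func_type_converter, corert_converter, symbol_table,
    tensor_array_side_effect_analysis, enable_while_parallel_iterations);
patterns->add<TFRTCallOpConversion<mlir::TF::StatefulPartitionedCallOp>,
              TFRTCallOpConversion<mlir::TF::PartitionedCallOp>,
              TFRTCallOpConversion<mlir::TF::LegacyCallOp>,
              TFRTCaseOpConversion, TFRTCondOpConversion>(
    context, func_type_converter, corert_converter);
patterns->add<FallbackConstDenseTensorOpConversion,
              FallbackConstStringTensorOpConversion>(context,
                                                     corert_converter);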
@@ -1598,7 +1569,6 @@ class TfToTfrtConversionPass tpu_allow_unpadded_batch_ = options.tpu_allow_unpadded_batch; cost_threshold_ = options.cost_threshold; merge_inter_dependent_streams_ = options.merge_inter_dependent_streams; - func_use_fallback_tensor_ = options.func_use_fallback_tensor; enable_while_parallel_iterations_ = options.enable_while_parallel_iterations; target_gpu_ = options.target_gpu; @@ -1633,19 +1603,15 @@ class TfToTfrtConversionPass } mlir::TypeConverter *func_type_converter; - if (func_use_fallback_tensor_) { - func_type_converter = &fallback_converter; - } else { - func_type_converter = &corert_converter; - } + func_type_converter = &fallback_converter; SetUpTFToTFRTConversionLegality(&target, func_type_converter, corert_converter.chain_type()); PopulateTFToTFRTConversionPatterns( &context, &patterns, &corert_converter, &fallback_converter, &symbol_table, &cost_analysis, &tensor_array_side_effect_analysis, - func_use_fallback_tensor_, enable_while_parallel_iterations_, - tpu_lower_to_fallback_, target_tpurt_, use_gpu_compile_and_execute_op_); + enable_while_parallel_iterations_, tpu_lower_to_fallback_, + target_tpurt_, use_gpu_compile_and_execute_op_); return mlir::applyPartialConversion(func, target, std::move(patterns)); } @@ -1721,7 +1687,7 @@ class TfToTfrtConversionPass auto return_op = llvm::cast(block.getTerminator()); auto chain = return_op->getOperand(0); - assert(chain.getType().isa()); + assert(mlir::isa(chain.getType())); dangling_values.push_back(chain); mlir::OpBuilder builder(return_op); @@ -1857,13 +1823,6 @@ class TfToTfrtConversionPass "preferred to be merged for inline execution."), llvm::cl::init(false)}; - Option func_use_fallback_tensor_{ - *this, "func-use-fallback-tensor", - llvm::cl::desc( - "If true, use TF tensor as input/output types in func (and other " - "control flow) ops."), - llvm::cl::init(false)}; - Option enable_while_parallel_iterations_{ *this, "enable-while-parallel-iterations", llvm::cl::desc("If true, tf.While op will be parallelized. This is " diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h index b1549838b11ce5..f893749011aca0 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_PIPELINE_OPTIONS_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_PIPELINE_OPTIONS_H_ +#include #include #include "llvm/Support/CommandLine.h" @@ -107,13 +108,6 @@ struct TfrtPipelineOptions llvm::cl::desc("If true, gpurt.compile_and_execute is used for GPU"), llvm::cl::init(false)}; - Option func_use_fallback_tensor{ - *this, "func-use-fallback-tensor", - llvm::cl::desc( - "If true, use TF tensor as input/output types in func (and other " - "control flow) ops."), - llvm::cl::init(false)}; - Option enable_while_parallel_iterations{ *this, "enable-while-parallel-iterations", llvm::cl::desc("If true, tf.While op will be parallelized. 
This is " @@ -144,6 +138,10 @@ struct TfrtPipelineOptions "cheap, and then whether it can be executed inline."), llvm::cl::init(1)}; + Option min_num_batch_threads{ + *this, "tfrt-min-num-batch-threads", + llvm::cl::desc("The minimum number of batch threads"), llvm::cl::init(1)}; + Option merge_inter_dependent_streams{ *this, "tfrt-merge-inter-dependent-streams", llvm::cl::desc("If true, streams with inter data depenedencies will be " diff --git a/tensorflow/compiler/mlir/tfrt/transforms/utils.cc b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc index 9b602babeafe22..711438f21d13f9 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/utils.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime @@ -34,7 +35,7 @@ limitations under the License. namespace tensorflow { bool IsResourceArgument(mlir::Value value) { - auto arg = value.dyn_cast(); + auto arg = mlir::dyn_cast(value); if (!arg) return false; auto func = llvm::cast(arg.getOwner()->getParentOp()); @@ -44,7 +45,7 @@ bool IsResourceArgument(mlir::Value value) { bool IsResultVariable(const mlir::Value &original_operand, const mlir::Value &operand) { - if (original_operand.isa()) { + if (mlir::isa(original_operand)) { auto defining_op = original_operand.getDefiningOp(); // TODO(b/174753886): When device assignment is properly done, we @@ -99,7 +100,8 @@ bool IsSessionInitializer(mlir::func::FuncOp op) { if (!session_initializer_op) return false; for (auto sym_ref : session_initializer_op.getInitializers()) { - if (op.getSymName() == sym_ref.cast().getValue()) + if (op.getSymName() == + mlir::cast(sym_ref).getValue()) return true; } diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 3cf8be9c90cb62..77e3c687f0c02a 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -123,7 +123,7 @@ absl::StatusOr> ExportXlaFunctions( func_op->walk([&](mlir::Operation* op) { for (const mlir::NamedAttribute& attr : op->getAttrs()) { if (const auto sym = - attr.getValue().dyn_cast()) { + mlir::dyn_cast(attr.getValue())) { mlir::Operation* func = mlir::SymbolTable::lookupNearestSymbolFrom(op, sym); if (func) { @@ -342,10 +342,11 @@ std::unique_ptr GetTfrtPipelineOptions( pipeline_options->hoist_invariant_ops = options.hoist_invariant_ops; pipeline_options->fuse_get_resource_ops_in_hoisting = options.fuse_get_resource_ops_in_hoisting; - pipeline_options->func_use_fallback_tensor = true; pipeline_options->enable_while_parallel_iterations = options.enable_while_parallel_iterations; pipeline_options->cost_threshold = options.cost_threshold; + pipeline_options->min_num_batch_threads = options.min_num_batch_threads; + pipeline_options->merge_inter_dependent_streams = options.merge_inter_dependent_streams; diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD index bc887cdfc966f9..cb517d1039711f 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD +++ 
b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD @@ -7,7 +7,6 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", - # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", # copybara:uncomment "//smartass/brain/ops/tfrt_kernels:__subpackages__", "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:__subpackages__", "//tensorflow/core/tfrt:__subpackages__", @@ -29,6 +28,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -43,6 +43,7 @@ tf_cc_test( "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:resource_loader", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc index 98cb26acdba8fa..06606c6fff345e 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" namespace mlrt { @@ -37,8 +38,8 @@ namespace { // LINT.IfChange(mlrt_attributes) bool CanBeInlined(mlir::Attribute attr, absl::string_view data) { // FlatSymbolRefAttr is a special case as we are emitting it as integer. - return attr.isa() && + return mlir::isa( + attr) && data.size() <= sizeof(uint32_t); } // LINT.ThenChange(../../../../../core/tfrt/mlrt/interpreter/attribute_span.h:mlrt_attributes) @@ -64,7 +65,7 @@ std::optional EncodeListOfInteger(mlir::ArrayAttr array) { mlir::Type type; for (int i = 0; i < array.size(); ++i) { - if (auto integer_attr = array[i].dyn_cast()) { + if (auto integer_attr = mlir::dyn_cast(array[i])) { if (type && integer_attr.getType() != type) return std::nullopt; type = integer_attr.getType(); llvm::APInt value = integer_attr.getValue(); @@ -85,7 +86,7 @@ std::optional EncodeListOfSymbolRef( auto ctor = bc::New>(&allocator, array.size()); for (int i = 0; i < array.size(); ++i) { - if (auto symbol_ref = array[i].dyn_cast()) { + if (auto symbol_ref = mlir::dyn_cast(array[i])) { ctor.ConstructAt(i, module_context.GetFunctionId(symbol_ref.getValue())); } else { return std::nullopt; @@ -117,7 +118,7 @@ std::optional EncodeListOfString(mlir::ArrayAttr array) { auto ctor = bc::New>(&allocator, array.size()); for (int i = 0; i < array.size(); ++i) { - if (auto string_attr = array[i].dyn_cast()) { + if (auto string_attr = mlir::dyn_cast(array[i])) { ctor.ConstructAt(i, string_attr.getValue().str()); } else { return std::nullopt; diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc index 8214d1d6deb3b3..07f1fbfdb0c0c1 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
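The inlining predicate from mlir_to_bytecode.cc above, reconstructed with the attribute list that the rendering dropped; the three attribute kinds are an assumption based on the encoder functions in the same file:

#include <cstdint>
#include "absl/strings/string_view.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"

// Attributes whose serialized form fits in four bytes are stored inline in
// the bytecode rather than emitted as separate data sections.
bool CanBeInlined(mlir::Attribute attr, absl::string_view data) {
  // FlatSymbolRefAttr is a special case as we are emitting it as integer.
  return mlir::isa<mlir::IntegerAttr, mlir::FloatAttr,
                   mlir::FlatSymbolRefAttr>(attr) &&
         data.size() <= sizeof(uint32_t);
}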
#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h" #include "tsl/platform/resource_loader.h" @@ -299,13 +300,13 @@ class CustomDense { absl::StatusOr EncodeCustomDense(const ModuleEmitterContext&, mlir::Attribute attr) { - auto dense_int_attr = attr.dyn_cast(); + auto dense_int_attr = mlir::dyn_cast(attr); if (!dense_int_attr) return absl::InvalidArgumentError( "The element of the custom dense attribute must be an integer."); - if (dense_int_attr.getElementType().cast().getWidth() != - 32) { + if (mlir::cast(dense_int_attr.getElementType()) + .getWidth() != 32) { return absl::InvalidArgumentError( "The element of the custom dense attribute must be an i32 integer."); } diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc index cf4950d5edbbf3..aadacd8563b7ea 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc @@ -19,8 +19,6 @@ limitations under the License. #include #include -#include "absl/strings/str_join.h" - namespace tensorflow { std::ostream& operator<<(std::ostream& os, @@ -40,8 +38,7 @@ std::ostream& operator<<(std::ostream& os, } std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options) { - return os << "{" - << "variable_device = " << options.variable_device + return os << "{" << "variable_device = " << options.variable_device << ", default_device = " << options.default_device << ", enable_optimizer = " << options.enable_optimizer << ", enable_grappler = " << options.enable_grappler @@ -58,6 +55,7 @@ std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options) { << ", enable_while_parallel_iterations = " << options.enable_while_parallel_iterations << ", cost_threshold = " << options.cost_threshold + << ", min_num_batch_threads = " << options.min_num_batch_threads << ", merge_inter_dependent_streams = " << options.merge_inter_dependent_streams << ", decompose_resource_ops = " << options.decompose_resource_ops diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h index 11790e9fa438a0..db50b062b8a209 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -136,6 +136,13 @@ struct TfrtCompileOptions { // expensive. uint64_t cost_threshold = 1; + // The minimum number of batch threads. This number provides a lower bound on + // the number of batch threads on top of what is specified in the model. If + // the number of batch threads is too small (e.g. smaller than the number of + // parallel hardware accelerator available), it can lead to under utilization + // of resources. + int64_t min_num_batch_threads = 1; + // If true, streams with inter data depenedencies will be preferred to be // merged for inline execution. 
 bool merge_inter_dependent_streams = true;
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
index 1069f3fd172411..e9dd7ef5f5e505 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -159,6 +159,7 @@ tf_cc_binary(
         "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes",
         "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:MlirOptLib",
+        "@llvm-project//mlir:Support",
         "@local_xla//xla/mlir_hlo:all_passes",
         "@local_xla//xla/mlir_hlo:hlo_dialect_registration",
         "@stablehlo//:register",
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD
index 42d679c35d0173..2862b79475a930 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD
@@ -92,6 +92,7 @@ cc_library(
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:SideEffectInterfaces",
+        "@llvm-project//mlir:Support",
         "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc
index a3b8c07cc1bb66..ee295c19335ff5 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.cc.inc"
 
 // Generated dialect definitions.
@@ -61,11 +62,11 @@ Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const {
 /// Print a type registered to this dialect.
 void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const {
-  if (type.isa<OpKernelContextType>()) {
+  if (mlir::isa<OpKernelContextType>(type)) {
     os << "op_kernel_context";
     return;
   }
-  if (type.isa<JITCallableType>()) {
+  if (mlir::isa<JITCallableType>(type)) {
     os << "jit_callable";
     return;
   }
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
index 277511fed098e0..404d134223ab48 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
@@ -52,6 +52,7 @@ limitations under the License.
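A hypothetical mirror of the printType hunk above, showing how the dialect's parseType counterpart conventionally round-trips the same keywords; the real parser in tf_framework_ops.cc may differ in details, so treat this as a sketch of the convention, not the file's code:

#include "mlir/IR/DialectImplementation.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"

// Assumed to live inside namespace mlir::kernel_gen::tf_framework.
Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword)) return Type();
  if (keyword == "op_kernel_context")
    return OpKernelContextType::get(getContext());
  if (keyword == "jit_callable") return JITCallableType::get(getContext());
  parser.emitError(parser.getNameLoc(), "unknown TF Framework type: ")
      << keyword;
  return Type();
}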
#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" // from @llvm-project @@ -91,7 +92,7 @@ bool IsSmallAlloc(Value alloc) { constexpr unsigned kMaximumSizeInBytes = 64; constexpr unsigned kMaxRankOfAllocatedMemRef = 1; - auto type = alloc.getType().dyn_cast(); + auto type = mlir::dyn_cast(alloc.getType()); if (!type || !alloc.getDefiningOp()) return false; if (!type.hasStaticShape()) { // Check if the dynamic shape dimension of the alloc is produced by RankOp @@ -176,8 +177,6 @@ Status LowerHlotoLoops(mlir::ModuleOp module, // Transform HLO operations to LinAlg and standard. pm.addNestedPass(::mlir::mhlo::createLegalizeHloToLinalgPass()); pm.addPass(::mlir::mhlo::createLegalizeToArithmeticPass()); - pm.addNestedPass( - mlir::mhlo::createLegalizeHloShapeOpsToStandardPass()); // Remove the remaining references to unsigned types after all HLO compute // operations were converted. diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir index 686b34e0d138db..c2a9404ebb926b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir @@ -46,7 +46,7 @@ func.func @AddV2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> at %cast = tensor.cast %23 : tensor to tensor<*xf32> scf.yield %cast : tensor<*xf32> } else { - %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %19:2 = mhlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor %20 = shape.rank %19#0 : tensor -> index %21 = shape.rank %19#1 : tensor -> index %22 = arith.cmpi sgt, %20, %21 : index diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir index f38a2dca1bc8cd..a83f91663951de 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir @@ -46,7 +46,7 @@ func.func @AddV2(%arg0: tensor<*xui32>, %arg1: tensor<*xui32>) -> tensor<*xui32> %cast = tensor.cast %23 : tensor to tensor<*xui32> scf.yield %cast : tensor<*xui32> } else { - %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %19:2 = mhlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor %20 = shape.rank %19#0 : tensor -> index %21 = shape.rank %19#1 : tensor -> index %22 = arith.cmpi sgt, %20, %21 : index diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir index 1facc06ee500e9..dc41eee4404837 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir @@ -46,7 +46,7 @@ func.func @Minimum_GPU_DT_UINT32_DT_UINT32(%arg0: tensor<*xui32>, %arg1: tensor< %cast = tensor.cast %23 : tensor to tensor<*xui32> scf.yield 
%cast : tensor<*xui32> } else { - %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %19:2 = mhlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor %20 = shape.rank %19#0 : tensor -> index %21 = shape.rank %19#1 : tensor -> index %22 = arith.cmpi sgt, %20, %21 : index diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc index 178e899cb33a72..dcbe88c3048eae 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc @@ -15,19 +15,18 @@ limitations under the License. #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" -#include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" int main(int argc, char **argv) { mlir::registerAllPasses(); mlir::mhlo::registerAllMhloPasses(); - mlir::lmhlo::registerAllLmhloPasses(); mlir::kernel_gen::registerKernelGenPasses(); mlir::DialectRegistry registry; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index d1c3af0b9a6191..489e13d172c059 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -39,6 +39,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], @@ -75,6 +76,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc index 37999960cd69e7..45dbbf993bb6be 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" @@ -115,7 +116,7 @@ class BufferReuseAnalysis { // Find reuse candidates for the regarded allocation. 
SmallVector local_reuse_candidates; for (BlockArgument old_buffer : arguments) { - if (!old_buffer.getType().isa()) continue; + if (!mlir::isa(old_buffer.getType())) continue; // Lifetime criterion: Only reuse buffers that are no longer used on // first reuse, i.e. they are no longer alive. @@ -177,15 +178,16 @@ class BufferReuseAnalysis { std::vector get_buffer_arguments(func::FuncOp &f) { std::vector buffer_arguments; for (BlockArgument arg : f.getArguments()) { - if (arg.getType().isa()) buffer_arguments.push_back(arg); + if (mlir::isa(arg.getType())) + buffer_arguments.push_back(arg); } return buffer_arguments; } bool can_reuse_locally(Operation *op, Value old_buffer, Value new_buffer) { // For now, we support only memrefs with the same memory layout. - auto old_buffer_ty = old_buffer.getType().dyn_cast(); - auto new_buffer_ty = old_buffer.getType().dyn_cast(); + auto old_buffer_ty = mlir::dyn_cast(old_buffer.getType()); + auto new_buffer_ty = mlir::dyn_cast(old_buffer.getType()); if (!old_buffer_ty || !new_buffer_ty || old_buffer_ty.getLayout() != new_buffer_ty.getLayout()) return false; @@ -205,7 +207,7 @@ class BufferReuseAnalysis { // Allow dropping dimensions but no permutations. int64_t i = -1; for (AffineExpr expr : map.getResults()) { - auto dim_expr = expr.dyn_cast(); + auto dim_expr = mlir::dyn_cast(expr); if (!dim_expr || dim_expr.getPosition() <= i) return false; i = dim_expr.getPosition(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc index 32faed506e52b4..9f41b399e2fd7f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" namespace mlir { @@ -135,7 +136,7 @@ void RemoveCopyIfTargetIsFunctionArg(func::FuncOp func) { Block &body = func.getBody().front(); for (auto &op : llvm::reverse(body.without_terminator())) { if (auto copy = dyn_cast(op)) { - auto block_arg = copy.getTarget().dyn_cast(); + auto block_arg = mlir::dyn_cast(copy.getTarget()); if (!block_arg) break; if (!isa(block_arg.getOwner()->getParentOp()) || !block_arg.hasOneUse()) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc index b7ad2d4d28b129..a6f23f1ad43aa8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc @@ -21,6 +21,7 @@ limitations under the License. 
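The layout-compatibility check from can_reuse_locally in buffer_reuse_pass.cc above, with the dyn_cast targets restored. Note that both locals derive from old_buffer.getType() in the source, before and after this commit; that behavior is preserved verbatim here rather than "fixed":

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

bool HaveCompatibleLayouts(mlir::Value old_buffer, mlir::Value new_buffer) {
  // new_buffer is intentionally unused below, mirroring the source, which
  // builds both types from old_buffer.
  (void)new_buffer;
  auto old_buffer_ty = mlir::dyn_cast<mlir::MemRefType>(old_buffer.getType());
  auto new_buffer_ty = mlir::dyn_cast<mlir::MemRefType>(old_buffer.getType());
  return old_buffer_ty && new_buffer_ty &&
         old_buffer_ty.getLayout() == new_buffer_ty.getLayout();
}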
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" @@ -64,7 +65,7 @@ std::optional FindOpKernelContext(Operation *op) { return std::nullopt; } Value ctx = func.getArgument(0); - if (!ctx.getType().isa()) { + if (!mlir::isa(ctx.getType())) { return std::nullopt; } return ctx; @@ -114,7 +115,8 @@ struct DeallocOpConverter : public OpConversionPattern { if (!ctx) return failure(); // Operand with no layout is expected. - auto operand_memref_type = dealloc.getMemref().getType().cast(); + auto operand_memref_type = + mlir::cast(dealloc.getMemref().getType()); if (!operand_memref_type.getLayout().isIdentity()) { return failure(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc index ed1138849e5a06..b5b22008dcb951 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" @@ -71,7 +72,7 @@ class EmbedTFFrameworkPass } FunctionType func_type = op.getFunctionType(); return func_type.getNumInputs() > 0 && - func_type.getInput(0).isa(); + mlir::isa(func_type.getInput(0)); }); target.addDynamicallyLegalOp(IsNotInsideTfEntryFunction); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc index a6b24b1a3afcc3..fa6ba2491d5906 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc @@ -217,7 +217,7 @@ class ShapeEqualityKnowledge { } if (auto alloc = dyn_cast(op)) { SmallVector shape; - ShapedType type = alloc.getResult().getType().cast(); + ShapedType type = mlir::cast(alloc.getResult().getType()); fillShapeFromAllocLike(alloc.getDynamicSizes(), type, shape); registerAssociation(ShapeValue{shape}, alloc.getResult()); return; @@ -225,7 +225,7 @@ class ShapeEqualityKnowledge { if (auto alloc = dyn_cast(op)) { // Construct a symbol representing the allocated shape. SmallVector shape; - ShapedType type = alloc.getResult().getType().cast(); + ShapedType type = mlir::cast(alloc.getResult().getType()); fillShapeFromAllocLike(alloc.getDynSizes(), type, shape); registerAssociation(ShapeValue{shape}, alloc.getResult()); return; @@ -331,7 +331,7 @@ struct PropagateShapeKnowledgeToKernels // Position of the kernel argument we are currently at. 
int kernel_p = 0; for (auto operand : launch.getKernelOperands()) { - auto memref = operand.getType().dyn_cast(); + auto memref = mlir::dyn_cast(operand.getType()); if (!memref) { // Scalar argument, advance kernel position by one. kernel_p++; @@ -341,7 +341,7 @@ struct PropagateShapeKnowledgeToKernels if (!knowledge.haveSameShape(operand, previous.first)) { continue; } - auto previous_type = previous.first.getType().cast(); + auto previous_type = mlir::cast(previous.first.getType()); // We use the first equality found and replace uses of corresponding // size and (potentially) stride information here. auto args_to_replace = memref.getRank(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc index 89ecd6da13be74..a7d26813239571 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc @@ -56,7 +56,7 @@ struct PropagateTfAbiKnowledgeToKernelsPass // the inner stride is one. // TODO(herhut): Insert asserts in debug mode to check this. for (auto argument : function.getArguments()) { - if (argument.getType().isa()) { + if (mlir::isa(argument.getType())) { worklist.push_back(argument); allocated_by_tf_runtime.insert(argument); offset_is_zero.insert(argument); @@ -95,7 +95,7 @@ struct PropagateTfAbiKnowledgeToKernelsPass llvm::SmallDenseMap constants; auto loc = kernel.getLoc(); for (auto operand : launch.getKernelOperands()) { - auto memref = operand.getType().dyn_cast(); + auto memref = mlir::dyn_cast(operand.getType()); if (!memref) { // Scalar argument, advance kernel position by one. kernel_p++; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index cffa5e7b44691e..8748b188f35dfa 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -25,6 +25,7 @@ limitations under the License. 
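The operand walk in PropagateShapeKnowledgeToKernels above follows a common lowering convention: a scalar occupies one kernel parameter, while each memref expands into descriptor fields. A sketch; the 3 + 2 * rank stride assumes the standard MemRef descriptor layout (two pointers, offset, sizes, strides) and is not taken from the diff:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

void WalkKernelOperands(mlir::gpu::LaunchFuncOp launch) {
  int kernel_p = 0;  // position of the kernel argument we are currently at
  for (mlir::Value operand : launch.getKernelOperands()) {
    auto memref = mlir::dyn_cast<mlir::MemRefType>(operand.getType());
    if (!memref) {
      kernel_p++;  // Scalar argument, advance kernel position by one.
      continue;
    }
    // Memref argument: skip over its expanded descriptor fields.
    kernel_p += 3 + 2 * memref.getRank();
  }
}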
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" @@ -96,14 +97,14 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern { Location loc, Type size_ty, Type element_ty, std::optional attr, ConversionPatternRewriter *rewriter) const { - assert(size_ty.isa() && "expect integer size type"); - assert(element_ty.isa() && "expect integer element type"); + assert(mlir::isa(size_ty) && "expect integer size type"); + assert(mlir::isa(element_ty) && "expect integer element type"); return ConvertArrayAttrToStackAllocatedArray( loc, size_ty, element_ty, attr, rewriter, [&](Attribute attr) { return rewriter->create( loc, element_ty, rewriter->getIntegerAttr(element_ty, - attr.cast().getInt())); + mlir::cast(attr).getInt())); }); } }; @@ -227,7 +228,7 @@ class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern { TFDeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO(herhut) Support unranked memrefs. - if (!op.getMemref().getType().isa()) return failure(); + if (!mlir::isa(op.getMemref().getType())) return failure(); MemRefDescriptor memref(adaptor.getMemref()); Value allocated_bytes_ptr = memref.allocatedPtr(rewriter, op.getLoc()); @@ -429,7 +430,7 @@ class ReportErrorOpConverter std::string err_str; llvm::raw_string_ostream err_stream(err_str); err_stream << message; - if (!loc.isa()) { + if (!mlir::isa(loc)) { err_stream << " at "; loc.print(err_stream); } @@ -465,16 +466,18 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { MLIRContext *ctx = null_memref_op.getContext(); mlir::Operation *op = null_memref_op.getOperation(); - auto shaped_result_type = null_memref_op.getType().cast(); - auto mem_space = - shaped_result_type.getMemorySpace().dyn_cast_or_null(); + auto shaped_result_type = + mlir::cast(null_memref_op.getType()); + auto mem_space = mlir::dyn_cast_or_null( + shaped_result_type.getMemorySpace()); unsigned address_space = static_cast(mem_space ? mem_space.getInt() : 0); LLVM::LLVMPointerType llvm_ptr_type = LLVM::LLVMPointerType::get(ctx, address_space); Value zero = createIndexAttrConstant(rewriter, loc, getIndexType(), 0); - if (auto result_type = null_memref_op.getType().dyn_cast()) { + if (auto result_type = + mlir::dyn_cast(null_memref_op.getType())) { // Set all dynamic sizes to 1 and compute fake strides. SmallVector dyn_sizes( result_type.getNumDynamicDims(), @@ -497,7 +500,7 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { return success(); } - auto result_type = null_memref_op.getType().cast(); + auto result_type = mlir::cast(null_memref_op.getType()); Type llvm_result_type = type_converter.convertType(result_type); auto desc = @@ -506,7 +509,7 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { // Extract address space and element type. auto targetType = - null_memref_op.getResult().getType().cast(); + mlir::cast(null_memref_op.getResult().getType()); unsigned addressSpace = *getTypeConverter()->getMemRefAddressSpace(targetType); @@ -549,7 +552,7 @@ class IsValidMemRefOpConverter MemRefDescriptor desc(adaptor.getArg()); // Compare every size in the descriptor to 0 to check num_elements == 0. 
- int64_t rank = op.getArg().getType().cast().getRank(); + int64_t rank = mlir::cast(op.getArg().getType()).getRank(); Value is_empty_shape = rewriter.create( loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); Value zero = createIndexAttrConstant(rewriter, loc, getIndexType(), 0); diff --git a/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir index 5719fd35989a6b..64fcdfc18d081f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir @@ -4,7 +4,7 @@ module { // CHECK-LABEL: @main func.func @main(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { - // CHECK: "tfl.call_once"() {session_init_function = "NoOp", session_init_function_symbol = @NoOp} : () -> () + // CHECK: "tfl.call_once"() <{session_init_function = "NoOp"}> {session_init_function_symbol = @NoOp} : () -> () "tfl.call_once"() {session_init_function = "NoOp"} : () -> () %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 53cbd84d6e441f..11648c9572b63c 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -100,7 +100,7 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te // ----- // CHECK-LABEL: test_real_div -// CHECK: %[[VAR0:.*]] = tosa.div %arg0, %arg1 +// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %arg1 func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { %2 = "tf.RealDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> func.return %2 : tensor<13x21x3xi32> @@ -109,7 +109,7 @@ func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) // ----- // CHECK-LABEL: test_floor_div -// CHECK: %[[VAR0:.*]] = tosa.div %arg0, %arg1 +// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %arg1 func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { %2 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> func.return %2 : tensor<13x21x3xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir index c73ece4991d513..77bafd5bc1ba9a 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: test_conv2d // CHECK-DAG: %[[VAR0:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> -// CHECK: %[[VAR1:.*]] = "tfl.conv_2d"(%arg0, %arg1, %[[VAR0]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} +// CHECK: %[[VAR1:.*]] = "tfl.conv_2d"(%arg0, %arg1, %[[VAR0]]) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>) -> tensor<*xf32> { %cst = arith.constant dense<0.000000e+00> : 
tensor<16xf32> %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<*xf32> @@ -15,7 +15,7 @@ func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32> // CHECK-LABEL: func @test_softmax( // CHECK-SAME:%[[VAR0:.*]]: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { -// CHECK: %[[VAR1:.*]] = "tfl.softmax"(%[[VAR0]]) {beta = 1.000000e+00 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAR1:.*]] = "tfl.softmax"(%[[VAR0]]) <{beta = 1.000000e+00 : f32}> : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: return %[[VAR1]] : tensor<13x21x3xf32> func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 57a5fdf02205f0..5ee9eaf4cd5517 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -370,7 +370,7 @@ func.func @test_rcp(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_div // CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %arg1 -// CHECK: %[[VAR0:.*]] = tosa.div %arg0, %[[RESHAPE]] +// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %[[RESHAPE]] func.func @test_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*xi32> { %0 = "tfl.div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor) -> tensor<*xi32> func.return %0 : tensor<*xi32> @@ -380,7 +380,7 @@ func.func @test_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*x // CHECK-LABEL: test_floor_div // CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %arg1 -// CHECK: %[[VAR0:.*]] = tosa.div %arg0, %[[RESHAPE]] +// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %[[RESHAPE]] func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*xi32> { %0 = "tfl.floor_div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor) -> tensor<*xi32> func.return %0 : tensor<*xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir index 401089a6d7cb99..6ad25dca4b8abd 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir @@ -65,7 +65,7 @@ module { module { // CHECK-LABEL: @nostate // CHECK: %[[VAL_0:.*]]: tensor<16x16xf32>) -> tensor<16x16xf32> { - // CHECK: %[[VAL_1:.*]] = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + // CHECK: %[[VAL_1:.*]] = "tfl.var_handle"() <{container = "", shared_name = "Variable"}> : () -> tensor<*x!tf_type.resource> // CHECK: %[[VAL_2:.*]] = "tfl.read_variable"(%[[VAL_1]]) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> // CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_2]], %[[VAL_0]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> // CHECK: "tfl.assign_variable"(%[[VAL_1]], %[[VAL_3]]) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc index 350d9e47545fb0..6523824611a603 
100644 --- a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -90,7 +91,7 @@ struct ConvertUint8QConstOp : public RewritePattern { } mlir::DenseElementsAttr src_dense_attr = - tfl_qconst_op.getValue().cast(); + mlir::cast(tfl_qconst_op.getValue()); double type_range_min = static_cast(output_element_type.getStorageTypeMin() - diff --git a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc index b64e4eda6d5e37..ba194e3e81c964 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -52,8 +53,8 @@ LogicalResult TosaDequantizeTFLSoftmaxPattern::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { TFL::SoftmaxOp tfl_softmax_op = cast(op); RankedTensorType input_type = - tfl_softmax_op.getInput().getType().cast(); - if (!input_type.getElementType().isa()) { + mlir::cast(tfl_softmax_op.getInput().getType()); + if (!mlir::isa(input_type.getElementType())) { return failure(); } Location loc = tfl_softmax_op.getLoc(); diff --git a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc index efee9aa9e9b9c2..ff07b9d6f91039 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc @@ -27,6 +27,7 @@ limitations under the License. 
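The attribute access from ConvertUint8QConstOp above with the cast target restored; a minimal sketch assuming the TFL dialect headers are available:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// The quantized constant's payload is a DenseElementsAttr; the free-function
// cast replaces the deprecated member cast used before this commit.
mlir::DenseElementsAttr GetQConstPayload(mlir::TFL::QConstOp tfl_qconst_op) {
  return mlir::cast<mlir::DenseElementsAttr>(tfl_qconst_op.getValue());
}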
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" @@ -80,7 +81,7 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( auto value = tf_biasadd_op.getValue(); auto bias = tf_biasadd_op.getBias(); - auto bias_shape = bias.getType().cast().getShape(); + auto bias_shape = mlir::cast(bias.getType()).getShape(); if (bias_shape.size() != 1) { return rewriter.notifyMatchFailure(op, "bias tensor must be rank 1"); } @@ -89,7 +90,8 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( llvm::dyn_cast_if_present(value.getDefiningOp())) { // Sanity check to confirm rhs() has the expected shape of bias auto filter_shape = - tf_conv2d_op.getFilter().getType().cast().getShape(); + mlir::cast(tf_conv2d_op.getFilter().getType()) + .getShape(); // Assume the filter shape is [H, W, I, O] if (filter_shape.back() != bias_shape.back()) { @@ -114,7 +116,8 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( llvm::dyn_cast_if_present(value.getDefiningOp())) { // Sanity check to confirm rhs() has the expected shape of bias auto filter_shape = - tf_conv3d_op.getFilter().getType().cast().getShape(); + mlir::cast(tf_conv3d_op.getFilter().getType()) + .getShape(); // Assume the filter shape is [D, H, W, I, O] if (filter_shape.back() != bias_shape.back()) { diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 3b461b8b36ae42..25707c2bde1331 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -56,6 +56,7 @@ limitations under the License. 
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" @@ -571,7 +572,7 @@ std::optional convertZerosLikeOp(PatternRewriter& rewriter, Attribute zero_attr = rewriter.getZeroAttr(zero_type); return CreateOpAndInfer(rewriter, op->getLoc(), zero_type, - zero_attr.cast()) + mlir::cast(zero_attr)) .getResult(); } @@ -586,12 +587,12 @@ std::optional convertMultiplyOp(PatternRewriter& rewriter, Operation* op, // Not a shaped tensor output if (!input_lhs_type || !input_rhs_type || !output_type) return std::nullopt; - bool input_lhs_is_qtype = - input_lhs_type.getElementType().isa(); - bool input_rhs_is_qtype = - input_rhs_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + bool input_lhs_is_qtype = mlir::isa( + input_lhs_type.getElementType()); + bool input_rhs_is_qtype = mlir::isa( + input_rhs_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_lhs_is_qtype != output_is_qtype || input_rhs_is_qtype != output_is_qtype) { @@ -603,12 +604,12 @@ std::optional convertMultiplyOp(PatternRewriter& rewriter, Operation* op, if (output_is_qtype) { ShapedType rescale_type = output_type.clone(rewriter.getI32Type()); - auto input_lhs_qtype = input_lhs_type.getElementType() - .cast(); - auto input_rhs_qtype = input_rhs_type.getElementType() - .cast(); - auto output_qtype = - output_type.getElementType().cast(); + auto input_lhs_qtype = mlir::cast( + input_lhs_type.getElementType()); + auto input_rhs_qtype = mlir::cast( + input_rhs_type.getElementType()); + auto output_qtype = mlir::cast( + output_type.getElementType()); // MLIR store scale as double, but TFLite store scale as float // Downcasting from double to float to match TFLite behavior @@ -661,11 +662,11 @@ std::optional convertSquaredDifferenceOp(PatternRewriter& rewriter, } bool x_is_qtype = - x_type.getElementType().isa(); + mlir::isa(x_type.getElementType()); bool y_is_qtype = - y_type.getElementType().isa(); - bool result_is_qtype = - result_type.getElementType().isa(); + mlir::isa(y_type.getElementType()); + bool result_is_qtype = mlir::isa( + result_type.getElementType()); if (x_is_qtype != result_is_qtype || y_is_qtype != result_is_qtype) { (void)rewriter.notifyMatchFailure( @@ -678,11 +679,11 @@ std::optional convertSquaredDifferenceOp(PatternRewriter& rewriter, // Then scale back to I8 if (result_is_qtype) { auto x_qtype = - x_type.getElementType().cast(); + mlir::cast(x_type.getElementType()); auto y_qtype = - y_type.getElementType().cast(); - auto result_qtype = - result_type.getElementType().cast(); + mlir::cast(y_type.getElementType()); + auto result_qtype = mlir::cast( + result_type.getElementType()); uint32_t result_bits = result_qtype.getStorageTypeIntegralWidth(); @@ -779,16 +780,16 @@ std::optional convertConcatV2Op(PatternRewriter& rewriter, Operation* op, } mlir::quant::UniformQuantizedType result_quant_type = - result_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + result_type.getElementType()); SmallVector values_rescaled; for (auto v : values) { RankedTensorType operand_type = dyn_cast(v.getType()); mlir::quant::UniformQuantizedType operand_quant_type = - operand_type.getElementType() - .dyn_cast_or_null(); + 
mlir::dyn_cast_or_null( + operand_type.getElementType()); // tfl.concat currently allows different scales for each input tensor, which // TFlite team will fix in: @@ -818,7 +819,8 @@ std::optional convertConcatV2Op(PatternRewriter& rewriter, Operation* op, } } - int32_t tensor_rank = values[0].getType().cast().getRank(); + int32_t tensor_rank = + mlir::cast(values[0].getType()).getRank(); if (axis < 0) axis += tensor_rank; if ((axis < 0) || (axis > tensor_rank)) { @@ -1046,7 +1048,8 @@ std::optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, // [padded_shape[M] / block_shape[M-1]] + // remaining_shape int32_t a2_reshape_a1_rank = - a2_reshape_a1_op.getResult().getType().cast().getRank(); + mlir::cast(a2_reshape_a1_op.getResult().getType()) + .getRank(); SmallVector a3_perm(a2_reshape_a1_rank); SmallVector a3_transpose_shape(a2_reshape_a1_rank); @@ -1579,17 +1582,19 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, int32_t input_rank = input_type.getShape().size(); ArrayRef logits_shape = output_type.getShape(); - if (input_type.getElementType().isa() && - output_type.getElementType().isa()) { + if (mlir::isa(input_type.getElementType()) && + mlir::isa(output_type.getElementType())) { SmallVector rsum_shape_v(input_type.getShape().begin(), input_type.getShape().end() - 1); rsum_shape_v.push_back(1); ArrayRef rsum_shape(rsum_shape_v); // The if condition already checks if these are UQTs mlir::quant::UniformQuantizedType in_quant_type = - input_type.getElementType().cast(); + mlir::cast( + input_type.getElementType()); mlir::quant::UniformQuantizedType out_quant_type = - output_type.getElementType().cast(); + mlir::cast( + output_type.getElementType()); auto int16_element_qtype = mlir::quant::UniformQuantizedType::get( true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0, @@ -2005,11 +2010,11 @@ std::optional convertLogSoftmaxOp(PatternRewriter& rewriter, } mlir::quant::UniformQuantizedType in_quant_type = - input_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + input_type.getElementType()); mlir::quant::UniformQuantizedType out_quant_type = - output_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + output_type.getElementType()); if (in_quant_type || out_quant_type) { (void)rewriter.notifyMatchFailure( op, "quantized log_softmax lowering not implemented yet"); @@ -2271,7 +2276,8 @@ std::optional> convertSplitOp( tensorflow::ConvertMlirShapeToTF(new_shape))); } - RankedTensorType slice_type = slice_value.getType().cast(); + RankedTensorType slice_type = + mlir::cast(slice_value.getType()); assert((slice_type.getDimSize(axis) % num_split) == 0); // Each slice has a different beginning point. 
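The guard that recurs across the legalize_common.cc hunks, reconstructed with its template arguments: a lowering bails out unless inputs and output agree on being uniform-quantized. A sketch with hypothetical parameter names:

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

bool AgreeOnQuantization(mlir::ShapedType input_type,
                         mlir::ShapedType output_type) {
  bool input_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
      input_type.getElementType());
  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
      output_type.getElementType());
  // Mixed quantized/float signatures are rejected by the patterns above.
  return input_is_qtype == output_is_qtype;
}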
@@ -2442,7 +2448,7 @@ std::optional<Value> convertStridedSliceOp(
   // Limitations:
   //   * This implementation only supports ellipsis_mask=0 for now
   auto input_type = dyn_cast<RankedTensorType>(input_value.getType());
-  ShapedType result_type = result_value.getType().cast<ShapedType>();
+  ShapedType result_type = mlir::cast<ShapedType>(result_value.getType());

   if (ellipsis_mask != 0) {
     (void)rewriter.notifyMatchFailure(op, "ellipses mask not supported yet");
@@ -2586,7 +2592,7 @@ std::optional<Value> convertStridedSliceOp(
   if (all_strides_one) {
     auto reversed =
         reverseNegativeStride(rewriter, op, a1_slice_op.getResult(), strides);
-    auto shape = reversed.getType().cast<ShapedType>().getShape();
+    auto shape = mlir::cast<ShapedType>(reversed.getType()).getShape();

     SmallVector<int64_t> new_shape;
     for (int i = 0; i < input_rank; ++i) {
@@ -2684,9 +2690,9 @@ std::optional<Value> convertFloorDivOp(PatternRewriter& rewriter, Operation* op,

   Type element_type = output_type.getElementType();

-  if (element_type.isa<IntegerType>()) {
-    return CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), output_type,
-                                         lhs_value, rhs_value)
+  if (mlir::isa<IntegerType>(element_type)) {
+    return CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), output_type,
+                                            lhs_value, rhs_value)
         .getResult();
   }

@@ -2738,14 +2744,14 @@ std::optional<Value> convertFusedActivation(PatternRewriter& rewriter,
   if (!input_type) return std::nullopt;

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());

   if (input_is_qtype) {
     // We can always make output/input tensor's scale/zp always be the same
     // when legalizing fused_activation_function, as it's generated during
     // legalization.
-    auto input_qtype =
-        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+    auto input_qtype = mlir::cast<mlir::quant::UniformQuantizedType>(
+        input_type.getElementType());

     if (fused_activation_fn.getValue() == "NONE") {
       return input_value;
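Because convertFusedActivation above can keep the input and output scale/zero-point identical, a fused activation in the quantized case reduces to a clamp in the quantized storage domain. A hedged sketch of the bound arithmetic only; `quantizedClampBound` is an illustrative stand-in, not a helper from this file:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Map a real-valued activation bound (e.g. 6.0 for RELU6) to a storage-domain
// clamp limit: q = round(real / scale) + zero_point, held within the storage
// range so the clamp is always representable.
int32_t quantizedClampBound(double real_bound, double scale,
                            int64_t zero_point, int32_t storage_min,
                            int32_t storage_max) {
  int64_t q = std::llround(real_bound / scale) + zero_point;
  return static_cast<int32_t>(
      std::min<int64_t>(std::max<int64_t>(q, storage_min), storage_max));
}

For int8 RELU6 the upper bound would be quantizedClampBound(6.0, scale, zp, -128, 127), with the zero point itself serving as the lower bound.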
@@ -3079,9 +3085,9 @@ std::optional<Value> convertReduceProdOp(PatternRewriter& rewriter,
   if (!input_type) return std::nullopt;

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype || output_is_qtype) {
     (void)rewriter.notifyMatchFailure(
@@ -3105,9 +3111,9 @@ std::optional<Value> convertReduceSumOp(PatternRewriter& rewriter,
   if (!input_type) return std::nullopt;

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     (void)rewriter.notifyMatchFailure(
@@ -3123,10 +3129,10 @@ std::optional<Value> convertReduceSumOp(PatternRewriter& rewriter,
   Type reduce_element_type = input_type.getElementType();

   if (input_is_qtype) {
-    auto input_qtype =
-        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
-    auto output_qtype =
-        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+    auto input_qtype = mlir::cast<mlir::quant::UniformQuantizedType>(
+        input_type.getElementType());
+    auto output_qtype = mlir::cast<mlir::quant::UniformQuantizedType>(
+        output_type.getElementType());

     int32_t input_shift = 20;

@@ -3164,9 +3170,9 @@ std::optional<Value> convertReduceMeanOp(PatternRewriter& rewriter,
   if (!input_type) return std::nullopt;

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     (void)rewriter.notifyMatchFailure(
@@ -3176,7 +3182,8 @@ std::optional<Value> convertReduceMeanOp(PatternRewriter& rewriter,
   }

   // Only supports float type mean() if it's non-quantized
-  if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
+  if (!input_is_qtype &&
+      !mlir::isa<mlir::FloatType>(output_type.getElementType())) {
     op->emitWarning("input unquantized type but output element not FloatType");
     return std::nullopt;
   }
@@ -3206,10 +3213,10 @@ std::optional<Value> convertReduceMeanOp(PatternRewriter& rewriter,
   int32_t output_scale_shift = 0;

   if (input_is_qtype) {
-    auto input_qtype =
-        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
-    auto output_qtype =
-        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+    auto input_qtype = mlir::cast<mlir::quant::UniformQuantizedType>(
+        input_type.getElementType());
+    auto output_qtype = mlir::cast<mlir::quant::UniformQuantizedType>(
+        output_type.getElementType());

     const int32_t scale_width = 32;
     computeMultiplierAndShift(1.0f, input_scale_multiplier, input_scale_shift,
@@ -3275,9 +3282,9 @@ std::optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
   }

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     (void)rewriter.notifyMatchFailure(
@@ -3287,7 +3294,7 @@ std::optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
   }

   if (!input_is_qtype) {
-    if (!input_type.getElementType().isa<mlir::FloatType>()) {
+    if (!mlir::isa<mlir::FloatType>(input_type.getElementType())) {
       (void)rewriter.notifyMatchFailure(
           op, "only quantized or float types supported");
       return std::nullopt;
@@ -3406,8 +3413,8 @@ std::optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
   // If quantized bilinear mode, need to lower to RESIZE + RESCALE pair.
   if (is_bilinear) {
     RankedTensorType output_acc_type;
-    auto input_element_qtype =
-        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+    auto input_element_qtype = mlir::cast<mlir::quant::UniformQuantizedType>(
+        input_type.getElementType());

     bool is_scale32;

@@ -3505,7 +3512,7 @@ std::optional<Value> convertQuantizeOp(PatternRewriter& rewriter, Operation* op,
   auto output_element_type = output_type.getElementType();

   // output element type could only be quantized integer
-  if (!output_element_type.isa<mlir::quant::QuantizedType>()) {
+  if (!mlir::isa<mlir::quant::QuantizedType>(output_element_type)) {
     (void)rewriter.notifyMatchFailure(
         op, "lowering quantizeOp but output element type not quantized");
     return std::nullopt;
@@ -3546,7 +3553,7 @@ std::optional<Value> convertDequantizeOp(PatternRewriter& rewriter,
   if (!input_type) return std::nullopt;

   // input element type could only be quantized integer
-  if (!input_type.getElementType().isa<mlir::quant::QuantizedType>())
+  if (!mlir::isa<mlir::quant::QuantizedType>(input_type.getElementType()))
     return std::nullopt;

   std::optional<Value> zp_val;
@@ -3839,8 +3846,8 @@ std::optional<Value> convertTFConv2DCommon(
     stride = rewriter.getDenseI64ArrayAttr({1, 1});
   } else {
     // Note: hardcoded to NHWC for now
-    int64_t stride_h = strides_attr[1].cast<IntegerAttr>().getInt();
-    int64_t stride_w = strides_attr[2].cast<IntegerAttr>().getInt();
+    int64_t stride_h = mlir::cast<IntegerAttr>(strides_attr[1]).getInt();
+    int64_t stride_w = mlir::cast<IntegerAttr>(strides_attr[2]).getInt();
     stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w});
   }
 }
@@ -3849,8 +3856,8 @@ std::optional<Value> convertTFConv2DCommon(
     dilation = rewriter.getDenseI64ArrayAttr({1, 1});
   } else {
     // Note: hardcoded to NHWC for now
-    int64_t dilation_h = dilations_attr[1].cast<IntegerAttr>().getInt();
-    int64_t dilation_w = dilations_attr[2].cast<IntegerAttr>().getInt();
+    int64_t dilation_h = mlir::cast<IntegerAttr>(dilations_attr[1]).getInt();
+    int64_t dilation_w = mlir::cast<IntegerAttr>(dilations_attr[2]).getInt();
     dilation = rewriter.getDenseI64ArrayAttr({dilation_h, dilation_w});
   }
 }
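convertReduceMeanOp above folds the 1/num_elements factor into a RESCALE through computeMultiplierAndShift. The decomposition is the standard fixed-point one; a hedged stand-in for the 32-bit case (our helper name, not the actual implementation in legalize_utils):

#include <cmath>
#include <cstdint>

// Decompose a real scale so that value * scale ~= (value * multiplier) >> shift,
// with multiplier a Q31 fixed-point mantissa. Sketch of the technique only.
void quantizeScale(double scale, int32_t& multiplier, int32_t& shift) {
  if (scale == 0.0) { multiplier = 0; shift = 0; return; }
  int exp = 0;
  const double mantissa = std::frexp(scale, &exp);   // scale = mantissa * 2^exp
  int64_t q = std::llround(mantissa * (1LL << 31));  // mantissa in [0.5, 1)
  if (q == (1LL << 31)) { q >>= 1; ++exp; }          // rounding overflowed
  multiplier = static_cast<int32_t>(q);
  shift = 31 - exp;
}

For scale = 1.0 this yields multiplier = 2^30 and shift = 30, so the rescale is exact; smaller scales trade shift for mantissa precision.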
@@ -3915,8 +3922,8 @@ std::optional<Value> convertConv3DCommon(PatternRewriter& rewriter,
   DenseI64ArrayAttr strides_attr = rewriter.getDenseI64ArrayAttr(strides);
   DenseI64ArrayAttr dilations_attr = rewriter.getDenseI64ArrayAttr(dilations);

-  RankedTensorType input_type = input.getType().cast<RankedTensorType>();
-  RankedTensorType filter_type = filter.getType().cast<RankedTensorType>();
+  RankedTensorType input_type = mlir::cast<RankedTensorType>(input.getType());
+  RankedTensorType filter_type = mlir::cast<RankedTensorType>(filter.getType());

   DenseI64ArrayAttr pads_attr;
   if (!getPaddingValuesFromPadType(tf_pad, data_format_tf, 0, input_type,
@@ -3963,9 +3970,9 @@ std::optional<Value> convertTFConv3DCommon(
     // Defaults to [1, 1, 1].
     strides = {1, 1, 1};
   } else {
-    int64_t stride_d = strides_attr[1].cast<IntegerAttr>().getInt();
-    int64_t stride_h = strides_attr[2].cast<IntegerAttr>().getInt();
-    int64_t stride_w = strides_attr[3].cast<IntegerAttr>().getInt();
+    int64_t stride_d = mlir::cast<IntegerAttr>(strides_attr[1]).getInt();
+    int64_t stride_h = mlir::cast<IntegerAttr>(strides_attr[2]).getInt();
+    int64_t stride_w = mlir::cast<IntegerAttr>(strides_attr[3]).getInt();
     strides = {stride_d, stride_h, stride_w};
   }

@@ -3974,9 +3981,9 @@ std::optional<Value> convertTFConv3DCommon(
     // Defaults to [1, 1, 1].
     dilations = {1, 1, 1};
   } else {
-    int64_t dilation_d = dilations_attr[1].cast<IntegerAttr>().getInt();
-    int64_t dilation_h = dilations_attr[2].cast<IntegerAttr>().getInt();
-    int64_t dilation_w = dilations_attr[3].cast<IntegerAttr>().getInt();
+    int64_t dilation_d = mlir::cast<IntegerAttr>(dilations_attr[1]).getInt();
+    int64_t dilation_h = mlir::cast<IntegerAttr>(dilations_attr[2]).getInt();
+    int64_t dilation_w = mlir::cast<IntegerAttr>(dilations_attr[3]).getInt();
     dilations = {dilation_d, dilation_h, dilation_w};
   }

@@ -4686,7 +4693,7 @@ std::optional<Value> convertSinOp(PatternRewriter& rewriter, Operation* op,
 std::optional<Value> convertSignOp(PatternRewriter& rewriter, Operation* op,
                                    Value input, RankedTensorType output_type) {
   auto output_elem_type = output_type.getElementType();
-  if (output_elem_type.isa<mlir::quant::QuantizedType>()) {
+  if (mlir::isa<mlir::quant::QuantizedType>(output_elem_type)) {
     (void)rewriter.notifyMatchFailure(op, "tfl quantization not yet supported");
     return std::nullopt;
   }
@@ -4695,7 +4702,7 @@ std::optional<Value> convertSignOp(PatternRewriter& rewriter, Operation* op,
   // one element.
   Value pos_one, neg_one, zero;
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-  if (output_elem_type.isa<FloatType>()) {
+  if (mlir::isa<FloatType>(output_elem_type)) {
     pos_one = getTosaConstTensorSingleF32(rewriter, op, 1.0f);
     neg_one = getTosaConstTensorSingleF32(rewriter, op, -1.0f);
     zero = getTosaConstTensorSingleF32(rewriter, op, 0.0f);
@@ -4733,7 +4740,7 @@ std::optional<Value> convertBroadcastToOp(PatternRewriter& rewriter,
   }

   Type element_type = input_type.getElementType();
-  if (element_type.isa<ComplexType>()) {
+  if (mlir::isa<ComplexType>(element_type)) {
     (void)rewriter.notifyMatchFailure(op, "input element type is complex");
     return std::nullopt;
   }
@@ -4816,7 +4823,7 @@ std::optional<Value> convertBroadcastToOp(PatternRewriter& rewriter,
   RankedTensorType output_type =
       tensorflow::GetTypeFromTFTensorShape(new_shape, element_type);

-  if (element_type.isa<FloatType>()) {
+  if (mlir::isa<FloatType>(element_type)) {
     // F32: legalize to broadcastable Add with (-0.f), instead of 0.f.
     // This is to preserve original values:
     // for corner case where x = -0.f
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
index 0ab48cc417fc98..904394d370bcce 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
@@ -257,7 +257,7 @@ LogicalResult ConvertTFSignOp::matchAndRewrite(
   auto tf_sign_op = cast<TF::SignOp>(op);

   RankedTensorType output_type =
-      tf_sign_op.getResult().getType().cast<RankedTensorType>();
+      mlir::cast<RankedTensorType>(tf_sign_op.getResult().getType());

   std::optional<Value> result =
       convertSignOp(rewriter, op, tf_sign_op.getX(), output_type);
@@ -270,7 +270,8 @@ LogicalResult ConvertTFSignOp::matchAndRewrite(
 LogicalResult ConvertTFSinOp::matchAndRewrite(Operation* op,
                                               PatternRewriter& rewriter) const {
   auto tf_sin_op = cast<TF::SinOp>(op);
-  ShapedType output_type = tf_sin_op.getResult().getType().cast<ShapedType>();
+  ShapedType output_type =
+      mlir::cast<ShapedType>(tf_sin_op.getResult().getType());

   std::optional<Value> result =
       convertSinOp(rewriter, op, tf_sin_op.getX(), output_type);
@@ -289,8 +290,8 @@ LogicalResult ConvertTFCosOp::matchAndRewrite(Operation* op,
   if (!input_ty || !output_ty) return failure();

-  bool input_is_fp = input_ty.getElementType().isa<mlir::FloatType>();
-  bool output_is_fp = output_ty.getElementType().isa<mlir::FloatType>();
+  bool input_is_fp = mlir::isa<mlir::FloatType>(input_ty.getElementType());
+  bool output_is_fp = mlir::isa<mlir::FloatType>(output_ty.getElementType());

   if (!input_is_fp || !output_is_fp) {
     return rewriter.notifyMatchFailure(
@@ -427,7 +428,7 @@ LogicalResult ConvertTFRoundOp::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "input not tensor type");
   }

-  if (input_type.getElementType().isa<FloatType>()) {
+  if (mlir::isa<FloatType>(input_type.getElementType())) {
     std::optional<Value> result = convertRoundOp(
         rewriter, op, tf_round_op.getResult(), tf_round_op.getX());

@@ -519,9 +520,9 @@ LogicalResult ConvertTFRealDivOp::matchAndRewrite(

   Type element_type = output_type.getElementType();

-  if (element_type.isa<IntegerType>()) {
-    CreateReplaceOpAndInfer<tosa::DivOp>(rewriter, op, output_type,
-                                         tf_div_op.getX(), tf_div_op.getY());
+  if (mlir::isa<IntegerType>(element_type)) {
+    CreateReplaceOpAndInfer<tosa::IntDivOp>(rewriter, op, output_type,
+                                            tf_div_op.getX(), tf_div_op.getY());
     return success();
   }

@@ -717,7 +718,8 @@ LogicalResult ConvertTFMaxPoolOp::matchAndRewrite(
 LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_concatv2_op = cast<TF::ConcatV2Op>(op);
-  auto result_type = tf_concatv2_op.getResult().getType().cast<ShapedType>();
+  auto result_type =
+      mlir::cast<ShapedType>(tf_concatv2_op.getResult().getType());
   SmallVector<Value> values(tf_concatv2_op.getValues());

   ElementsAttr axis_elems;
@@ -877,7 +879,7 @@ LogicalResult ConvertTFFillOp::matchAndRewrite(
   DenseArrayAttr fill_attr;

   // Convert to a compatible zero type
-  if (value_elem.getShapedType().getElementType().isa<FloatType>()) {
+  if (mlir::isa<FloatType>(value_elem.getShapedType().getElementType())) {
     SmallVector<float> fill_arr(
         total_size,
         value_elem.getValues<FloatAttr>()[0].getValue().convertToFloat());
@@ -891,7 +893,7 @@ LogicalResult ConvertTFFillOp::matchAndRewrite(
         DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr));
   }
   auto fill_const_op = CreateOpAndInfer<tosa::ConstOp>(
-      rewriter, op->getLoc(), fill_type, fill_attr.cast<ElementsAttr>());
+      rewriter, op->getLoc(), fill_type, mlir::cast<ElementsAttr>(fill_attr));
   rewriter.replaceOp(op, {fill_const_op.getResult()});

   return success();
@@ -911,8 +913,8 @@ LogicalResult ConvertTFConv2DOp::matchAndRewrite(
   RankedTensorType bias_type = tensorflow::GetTypeFromTFTensorShape(
       {bias_dim}, filter_type.getElementType());
   auto bias_attr = rewriter.getZeroAttr(bias_type);
-  auto bias = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), bias_type,
-                                              bias_attr.cast<ElementsAttr>());
+  auto bias = CreateOpAndInfer<tosa::ConstOp>(
+      rewriter, op->getLoc(), bias_type, mlir::cast<ElementsAttr>(bias_attr));

   std::optional<Value> result = convertTFConv2DCommon(
       rewriter, op, output_type, tf_conv2d_op.getInput(),
@@ -946,8 +948,8 @@ LogicalResult ConvertTFConv3DOp::matchAndRewrite(
   RankedTensorType bias_type =
       RankedTensorType::get({bias_dim}, filter_type.getElementType());
   auto bias_attr = rewriter.getZeroAttr(bias_type);
-  auto bias = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), bias_type,
-                                              bias_attr.cast<ElementsAttr>());
+  auto bias = CreateOpAndInfer<tosa::ConstOp>(
+      rewriter, op->getLoc(), bias_type, mlir::cast<ElementsAttr>(bias_attr));

   std::optional<Value> result = convertTFConv3DCommon(
       rewriter, op, output_type, tf_conv3d_op.getInput(),
@@ -1036,8 +1038,8 @@ LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite(
   RankedTensorType bias_type = tensorflow::GetTypeFromTFTensorShape(
       {bias_dim}, filter_type.getElementType());
   auto bias_attr = rewriter.getZeroAttr(bias_type);
-  auto bias = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), bias_type,
-                                              bias_attr.cast<ElementsAttr>());
+  auto bias = CreateOpAndInfer<tosa::ConstOp>(
+      rewriter, op->getLoc(), bias_type, mlir::cast<ElementsAttr>(bias_attr));

   CreateReplaceOpAndInfer<tosa::DepthwiseConv2DOp>(
       rewriter, op, output_type, tf_dwconv2d_op.getInput(),
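The three TF conv hunks above share one idiom: TOSA convolutions require a bias operand, so when the TF op has none, an all-zero constant is synthesized with getZeroAttr plus tosa::ConstOp. A hedged, stand-alone rendering of that idiom; `makeZeroBias` is our name, not a function in this patch:

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LLVM.h"

// Build a rank-1 all-zero bias tensor with `channels` elements, as the conv
// legalizations above do before calling convertTFConv{2,3}DCommon.
mlir::Value makeZeroBias(mlir::OpBuilder& builder, mlir::Location loc,
                         int64_t channels, mlir::Type element_type) {
  auto bias_type = mlir::RankedTensorType::get({channels}, element_type);
  auto zero_attr =
      mlir::cast<mlir::ElementsAttr>(builder.getZeroAttr(bias_type));
  return builder.create<mlir::tosa::ConstOp>(loc, bias_type, zero_attr);
}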
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
index b0c9a7189a50aa..abb5b80b92bfeb 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -321,9 +321,9 @@ LogicalResult ConvertTFLReluOp::matchAndRewrite(
   if (!input_type || !output_type) return failure();

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     return rewriter.notifyMatchFailure(
@@ -373,9 +373,9 @@ LogicalResult ConvertTFLRelu1Op::matchAndRewrite(
   if (!input_type || !output_type) return failure();

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     return rewriter.notifyMatchFailure(
@@ -423,14 +423,15 @@ LogicalResult ConvertTFLRelu0To1Op::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_relu0to1_op = cast<TFL::Relu0To1Op>(op);

-  ShapedType input_type = tfl_relu0to1_op.getX().getType().cast<ShapedType>();
+  ShapedType input_type =
+      mlir::cast<ShapedType>(tfl_relu0to1_op.getX().getType());
   ShapedType output_type =
-      tfl_relu0to1_op.getResult().getType().cast<ShapedType>();
+      mlir::cast<ShapedType>(tfl_relu0to1_op.getResult().getType());

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     return rewriter.notifyMatchFailure(
@@ -444,9 +445,11 @@ LogicalResult ConvertTFLRelu0To1Op::matchAndRewrite(

   if (output_is_qtype && input_is_qtype) {
     UniformQuantizedType input_qtype =
-        input_type.getElementType().cast<UniformQuantizedType>();
+        mlir::cast<UniformQuantizedType>(
+            input_type.getElementType());
     UniformQuantizedType output_qtype =
-        output_type.getElementType().cast<UniformQuantizedType>();
+        mlir::cast<UniformQuantizedType>(
+            output_type.getElementType());

     clamp_min = output_qtype.getZeroPoint();
@@ -482,9 +485,9 @@ LogicalResult ConvertTFLRelu6Op::matchAndRewrite(
   if (!input_type || !output_type) return failure();

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     return rewriter.notifyMatchFailure(
@@ -539,12 +542,12 @@ static LogicalResult prepareMatchAndRewriteComparison(
   // Not a shaped tensor output
   if (!input_x_type || !input_y_type || !output_type) return failure();

-  bool input_x_is_qtype =
-      input_x_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool input_y_is_qtype =
-      input_y_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+  bool input_x_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      input_x_type.getElementType());
+  bool input_y_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      input_y_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_x_is_qtype != input_y_is_qtype ||
       input_y_is_qtype != output_is_qtype) {
@@ -671,20 +674,20 @@ static LogicalResult matchAndRewriteAddSub(Operation* op,
   auto tfl_add_op = cast<TflOp>(op);

   ShapedType input_lhs_type =
-      tfl_add_op.getLhs().getType().template dyn_cast<ShapedType>();
+      mlir::dyn_cast<ShapedType>(tfl_add_op.getLhs().getType());
   ShapedType input_rhs_type =
-      tfl_add_op.getRhs().getType().template dyn_cast<ShapedType>();
+      mlir::dyn_cast<ShapedType>(tfl_add_op.getRhs().getType());
   ShapedType output_type =
-      tfl_add_op.getResult().getType().template dyn_cast<ShapedType>();
+      mlir::dyn_cast<ShapedType>(tfl_add_op.getResult().getType());
   // Not a ranked tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

-  bool input_lhs_is_qtype =
-      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool input_rhs_is_qtype =
-      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+  bool input_lhs_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      input_lhs_type.getElementType());
+  bool input_rhs_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      input_rhs_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_lhs_is_qtype != output_is_qtype ||
       input_rhs_is_qtype != output_is_qtype) {
@@ -847,7 +850,7 @@ LogicalResult ConvertTFLSignOp::matchAndRewrite(
   auto tfl_sign_op = cast<TFL::SignOp>(op);

   RankedTensorType output_type =
-      tfl_sign_op.getResult().getType().cast<RankedTensorType>();
+      mlir::cast<RankedTensorType>(tfl_sign_op.getResult().getType());

   std::optional<Value> result =
       convertSignOp(rewriter, op, tfl_sign_op.getX(), output_type);
@@ -932,7 +935,7 @@ LogicalResult ConvertTFLRoundOp::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "input not shaped tensor type");
   }

-  if (input_type.getElementType().isa<FloatType>()) {
+  if (mlir::isa<FloatType>(input_type.getElementType())) {
     std::optional<Value> result = convertRoundOp(
         rewriter, op, tfl_round_op.getResult(), tfl_round_op.getX());

@@ -962,10 +965,11 @@ LogicalResult ConvertTFLDivOp::matchAndRewrite(
   Type element_type = output_type.getElementType();
   Value div_op;

-  if (element_type.isa<IntegerType>()) {
+  if (mlir::isa<IntegerType>(element_type)) {
     div_op =
-        CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), output_type,
-                                      tfl_div_op.getLhs(), tfl_div_op.getRhs())
+        CreateOpAndInfer<tosa::IntDivOp>(
+            rewriter, op->getLoc(), output_type, tfl_div_op.getLhs(),
+            tfl_div_op.getRhs())
             .getResult();
   } else {
     auto reciprocal_op = CreateOpAndInfer<tosa::ReciprocalOp>(
@@ -1006,12 +1010,12 @@ LogicalResult ConvertTFLMaximumOp::matchAndRewrite(
   // Not a shaped tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

-  bool input_lhs_is_qtype =
-      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool input_rhs_is_qtype =
-      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+  bool input_lhs_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      input_lhs_type.getElementType());
+  bool input_rhs_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      input_rhs_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_lhs_is_qtype != output_is_qtype ||
       input_rhs_is_qtype != output_is_qtype) {
@@ -1062,12 +1066,12 @@ LogicalResult ConvertTFLMinimumOp::matchAndRewrite(
   // Not a shaped tensor output
   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

-  bool input_lhs_is_qtype =
-      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool input_rhs_is_qtype =
-      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+  bool input_lhs_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      input_lhs_type.getElementType());
+  bool input_rhs_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      input_rhs_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_lhs_is_qtype != output_is_qtype ||
       input_rhs_is_qtype != output_is_qtype) {
@@ -1215,12 +1219,12 @@ LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite(
   // Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time
   // FP16 is supported, the accumulator type can be selected based on trade-off
   // between performance and accuracy. Set to FP32 by default.
-  TypeAttr acc_attr = average_etype.isa<FloatType>()
+  TypeAttr acc_attr = mlir::isa<FloatType>(average_etype)
                           ? mlir::TypeAttr::get(rewriter.getF32Type())
                           : mlir::TypeAttr::get(rewriter.getIntegerType(32));

   Value result;
-  if (average_etype.isa<mlir::quant::UniformQuantizedType>()) {
+  if (mlir::isa<mlir::quant::UniformQuantizedType>(average_etype)) {
     // TensorFlow Lite doesn't use the zero point when calculating
     // quantized average pool, while TOSA does. Force the TOSA
     // zero_points to zero to ensure that the calculations match
@@ -1445,11 +1449,11 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
   if (!filter_type) return failure();

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
   bool filter_is_qtype =
-      filter_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(filter_type.getElementType());
   bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

   if ((input_is_qtype != filter_is_qtype) ||
       (input_is_qtype != output_is_qtype)) {
@@ -1499,7 +1503,7 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
       output_is_qtype ? rewriter.getI32Type() : output_type.getElementType();
   if (unquantized_bias) {
     Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType());
-    if (auto qtype = new_bias_ety.dyn_cast<mlir::quant::QuantizedType>()) {
+    if (auto qtype = mlir::dyn_cast<mlir::quant::QuantizedType>(new_bias_ety)) {
       new_bias_ety = qtype.getStorageType();
     }
     if (new_bias_ety.getIntOrFloatBitWidth() >
@@ -1555,11 +1559,11 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite(
   }

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
   bool filter_is_qtype =
-      filter_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(filter_type.getElementType());
   bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

   if ((input_is_qtype != filter_is_qtype) ||
       (input_is_qtype != output_is_qtype)) {
@@ -1578,7 +1582,7 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite(
         RankedTensorType::get({bias_dim}, filter_type.getElementType());
     auto bias_attr = rewriter.getZeroAttr(bias_type);
     unquantized_bias = CreateOpAndInfer<tosa::ConstOp>(
-        rewriter, op->getLoc(), bias_type, bias_attr.cast<ElementsAttr>());
+        rewriter, op->getLoc(), bias_type, mlir::cast<ElementsAttr>(bias_attr));
   }

   SmallVector<int64_t> strides({tfl_conv3d_op.getStrideD(),
@@ -1588,7 +1592,7 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite(
                                   tfl_conv3d_op.getDilationHFactor(),
                                   tfl_conv3d_op.getDilationWFactor()});
   Type bias_ety =
-      unquantized_bias.getType().cast<ShapedType>().getElementType();
+      mlir::cast<ShapedType>(unquantized_bias.getType()).getElementType();
   std::optional<Value> a1_conv3d_op = convertConv3DCommon(
       rewriter, op, output_type.clone(bias_ety), tfl_conv3d_op.getInput(),
       tfl_conv3d_op.getFilter(), unquantized_bias, strides, dilations,
@@ -1634,11 +1638,11 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite(
   if (!filter_type) return failure();

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
   bool filter_is_qtype =
-      filter_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(filter_type.getElementType());
   bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

   if ((input_is_qtype != filter_is_qtype) ||
       (input_is_qtype != output_is_qtype)) {
@@ -1721,7 +1725,7 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite(
   }

   if (!zero_bias) return failure();
-  Type bias_ety = zero_bias->getType().cast<ShapedType>().getElementType();
+  Type bias_ety = mlir::cast<ShapedType>(zero_bias->getType()).getElementType();

   auto a1_conv2d_op = CreateOpAndInfer<tosa::TransposeConv2DOp>(
       rewriter, op->getLoc(), output_type.clone(bias_ety),
@@ -1770,11 +1774,11 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite(
   if (!filter_type) return failure();

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
   bool filter_is_qtype =
-      filter_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(filter_type.getElementType());
   bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

   if ((input_is_qtype != filter_is_qtype) ||
       (input_is_qtype != output_is_qtype)) {
@@ -1863,7 +1867,7 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite(
   Value unquantized_bias = tfl_conv2d_op.getBias();
   if (unquantized_bias) {
     Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType());
-    if (auto qtype = new_bias_ety.dyn_cast<mlir::quant::QuantizedType>()) {
+    if (auto qtype = mlir::dyn_cast<mlir::quant::QuantizedType>(new_bias_ety)) {
       new_bias_ety = qtype.getStorageType();
     }
     if (new_bias_ety.getIntOrFloatBitWidth() >
@@ -1906,7 +1910,7 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_mm_op = cast<TFL::BatchMatMulOp>(op);
-  auto result_ty = tfl_mm_op.getType().cast<ShapedType>();
+  auto result_ty = mlir::cast<ShapedType>(tfl_mm_op.getType());
   Value lhs = tfl_mm_op.getX();
   Value rhs = tfl_mm_op.getY();
   RankedTensorType lhs_ty = dyn_cast<RankedTensorType>(lhs.getType());
@@ -1916,10 +1920,12 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite(
   if (!lhs_ty || !rhs_ty) return failure();

-  bool lhs_is_qtype = lhs_ty.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool rhs_is_qtype = rhs_ty.getElementType().isa<mlir::quant::UniformQuantizedType>();
+  bool lhs_is_qtype =
+      mlir::isa<mlir::quant::UniformQuantizedType>(lhs_ty.getElementType());
+  bool rhs_is_qtype =
+      mlir::isa<mlir::quant::UniformQuantizedType>(rhs_ty.getElementType());
   bool result_is_qtype =
-      result_ty.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(result_ty.getElementType());

   if ((lhs_is_qtype != rhs_is_qtype) || (lhs_is_qtype != result_is_qtype)) {
     return rewriter.notifyMatchFailure(
@@ -1951,8 +1957,8 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite(
         rewriter, op->getLoc(), UnrankedTensorType::get(rhs_ty.getElementType()),
         rhs, rewriter.getDenseI64ArrayAttr(new_rhs_shape));
-    lhs_ty = lhs.getType().cast<RankedTensorType>();
-    rhs_ty = rhs.getType().cast<RankedTensorType>();
+    lhs_ty = mlir::cast<RankedTensorType>(lhs.getType());
+    rhs_ty = mlir::cast<RankedTensorType>(rhs.getType());
   }

   if (transpose_lhs) {
@@ -1977,12 +1983,12 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite(
   Type output_ety;
   if (result_is_qtype) {
-    auto lhs_qty_width = lhs_ty.getElementType()
-                             .cast<mlir::quant::UniformQuantizedType>()
-                             .getStorageTypeIntegralWidth();
-    auto rhs_qty_width = rhs_ty.getElementType()
-                             .cast<mlir::quant::UniformQuantizedType>()
-                             .getStorageTypeIntegralWidth();
+    auto lhs_qty_width =
+        mlir::cast<mlir::quant::UniformQuantizedType>(lhs_ty.getElementType())
+            .getStorageTypeIntegralWidth();
+    auto rhs_qty_width =
+        mlir::cast<mlir::quant::UniformQuantizedType>(rhs_ty.getElementType())
+            .getStorageTypeIntegralWidth();

     if (lhs_qty_width != rhs_qty_width) {
       return rewriter.notifyMatchFailure(
@@ -2001,13 +2007,13 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite(
     output_ety = result_ty.getElementType();
   }

-  auto matmul =
+  Value matmul =
       CreateOpAndInfer<tosa::MatMulOp>(
           rewriter, op->getLoc(), UnrankedTensorType::get(output_ety), lhs, rhs)
           .getResult();

   // Conditionally reshape rank back to expected rank.
-  auto matmul_ty = matmul.getType().cast<RankedTensorType>();
+  auto matmul_ty = mlir::cast<RankedTensorType>(matmul.getType());
   if (batch_dims.size() != 1) {
     llvm::SmallVector<int64_t> new_shape{};
     for (auto d : batch_dims) {
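tosa.matmul in the hunk above is strictly rank-3 ([N, H, W] x [N, W, C]), which is why ConvertTFLBatchMatMulOp reshapes higher-rank operands down before multiplying and reshapes back afterwards. A hedged sketch of the shape bookkeeping for static shapes; `flattenToRank3` is illustrative, not from this file:

#include <cstdint>
#include <vector>

// Collapse all leading batch dimensions into N: {2, 3, 4, 5} -> {6, 4, 5}.
// Assumes a static shape of rank >= 3; dynamic dims need separate handling.
std::vector<int64_t> flattenToRank3(const std::vector<int64_t>& shape) {
  int64_t batch = 1;
  for (size_t i = 0; i + 2 < shape.size(); ++i) batch *= shape[i];
  return {batch, shape[shape.size() - 2], shape[shape.size() - 1]};
}

After the multiply, the saved batch_dims are prepended again, which is the `new_shape` loop visible at the end of the hunk.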
@@ -2052,11 +2058,11 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
   if (!input_type || !filter_type) return failure();

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
   bool filter_is_qtype =
-      filter_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(filter_type.getElementType());
   bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

   if ((input_is_qtype != filter_is_qtype) ||
       (input_is_qtype != output_is_qtype)) {
@@ -2099,7 +2105,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
     RankedTensorType new_bias_type;

     DenseElementsAttr bias_attr;
-    if (input_type.getElementType().isa<FloatType>()) {
+    if (mlir::isa<FloatType>(input_type.getElementType())) {
       SmallVector<float> bias_arr(bias_shape[0]);

       for (int i = 0; i < bias_shape[0]; i++) {
@@ -2120,7 +2126,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
             op, "input must be quantized type if it's not float type");
       }
       auto input_qtype =
-          input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
+          mlir::cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
       Type new_bias_ety = input_qtype.getStorageTypeIntegralWidth() == 16
                               ? rewriter.getIntegerType(48)
                               : rewriter.getI32Type();
@@ -2136,7 +2142,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
     bias_val = tfl_fc_op.getBias();
   }

-  Type bias_ety = bias_val.getType().cast<ShapedType>().getElementType();
+  Type bias_ety = mlir::cast<ShapedType>(bias_val.getType()).getElementType();

   auto fc_op = CreateOpAndInfer<tosa::FullyConnectedOp>(
       rewriter, op->getLoc(), UnrankedTensorType::get(bias_ety), input_val,
@@ -2152,7 +2158,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
   }

   // If we know the output rank, we need to ensure the output shape is correct.
-  ShapedType fc_type = fc_output.getType().cast<ShapedType>();
+  ShapedType fc_type = mlir::cast<ShapedType>(fc_output.getType());
   if (output_type.hasRank()) {
     llvm::SmallVector<int64_t> output_shape;

@@ -2270,7 +2276,7 @@ LogicalResult ConvertTFLRankOp::matchAndRewrite(
       RankedTensorType::get({1}, rewriter.getIntegerType(32));
   auto rank_attr = DenseI32ArrayAttr::get(rewriter.getContext(), {rank});
   auto rank_const = CreateOpAndInfer<tosa::ConstOp>(
-      rewriter, op->getLoc(), rank_type, rank_attr.cast<ElementsAttr>());
+      rewriter, op->getLoc(), rank_type, mlir::cast<ElementsAttr>(rank_attr));

   rewriter.replaceOp(op, {rank_const.getResult()});

@@ -2303,7 +2309,7 @@ LogicalResult ConvertTFLShapeOp::matchAndRewrite(
   auto shape_attr =
       DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(shape_arr));
   auto shape_const = CreateOpAndInfer<tosa::ConstOp>(
-      rewriter, op->getLoc(), shape_type, shape_attr.cast<ElementsAttr>());
+      rewriter, op->getLoc(), shape_type, mlir::cast<ElementsAttr>(shape_attr));

   rewriter.replaceOp(op, {shape_const.getResult()});

@@ -2376,7 +2382,7 @@ LogicalResult ConvertTFLFillOp::matchAndRewrite(
   DenseArrayAttr fill_attr;

   // Convert to a compatible zero type.
-  if (value_elem.getShapedType().getElementType().isa<FloatType>()) {
+  if (mlir::isa<FloatType>(value_elem.getShapedType().getElementType())) {
     SmallVector<float> fill_arr(
         total_size, value_elem.getValues<APFloat>()[0].convertToFloat());
     fill_attr =
@@ -2388,7 +2394,7 @@ LogicalResult ConvertTFLFillOp::matchAndRewrite(
         DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr));
   }
   auto fill_const_op = CreateOpAndInfer<tosa::ConstOp>(
-      rewriter, op->getLoc(), fill_type, fill_attr.cast<ElementsAttr>());
+      rewriter, op->getLoc(), fill_type, mlir::cast<ElementsAttr>(fill_attr));
   rewriter.replaceOp(op, {fill_const_op.getResult()});

   return success();
@@ -2589,11 +2595,11 @@ LogicalResult ConvertTFLRsqrtOp::matchAndRewrite(
       dyn_cast<RankedTensorType>(tfl_rsqrt_op.getX().getType());

   mlir::quant::UniformQuantizedType input_qtype =
-      input_type.getElementType()
-          .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+      mlir::dyn_cast_or_null<mlir::quant::UniformQuantizedType>(
+          input_type.getElementType());
   mlir::quant::UniformQuantizedType output_qtype =
-      output_type.getElementType()
-          .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+      mlir::dyn_cast_or_null<mlir::quant::UniformQuantizedType>(
+          output_type.getElementType());

   // Quantization case
   if (input_qtype && output_qtype) {
@@ -2636,7 +2642,7 @@ LogicalResult ConvertTFLL2NormalizationOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_l2norm_op = cast<TFL::L2NormalizationOp>(op);
   auto input = tfl_l2norm_op.getInput();
-  auto input_ty = input.getType().cast<ShapedType>();
+  auto input_ty = mlir::cast<ShapedType>(input.getType());
   auto loc = op->getLoc();

   if (!input_ty.hasRank()) return failure();
@@ -3200,15 +3206,15 @@ LogicalResult ConvertTFLHardSwishOp::matchAndRewrite(

   // TFL hardswish: f(x) -> (x * relu6(x+3))/6

-  if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
-      output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
+  if (mlir::isa<mlir::quant::QuantizedType>(input_type.getElementType()) &&
+      mlir::isa<mlir::quant::QuantizedType>(output_type.getElementType())) {
     // Should match TFLite reference numerical behavior
     mlir::quant::UniformQuantizedType input_qtype =
-        input_type.getElementType()
-            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+        mlir::dyn_cast_or_null<mlir::quant::UniformQuantizedType>(
+            input_type.getElementType());
     mlir::quant::UniformQuantizedType output_qtype =
-        output_type.getElementType()
-            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+        mlir::dyn_cast_or_null<mlir::quant::UniformQuantizedType>(
+            output_type.getElementType());

     auto hardswish_func = [](double v) -> double {
       double w = v + 3.0;
@@ -3286,8 +3292,8 @@ LogicalResult ConvertTFLCosOp::matchAndRewrite(
   if (!input_ty || !output_ty) return failure();

-  bool input_is_fp = input_ty.getElementType().isa<mlir::FloatType>();
-  bool output_is_fp = output_ty.getElementType().isa<mlir::FloatType>();
+  bool input_is_fp = mlir::isa<mlir::FloatType>(input_ty.getElementType());
+  bool output_is_fp = mlir::isa<mlir::FloatType>(output_ty.getElementType());

   if (!input_is_fp || !output_is_fp) {
     return rewriter.notifyMatchFailure(op, "input/result must be fp");
@@ -3440,9 +3446,9 @@ LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
   if (!input_type || !output_type) return failure();

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     return rewriter.notifyMatchFailure(
@@ -3453,11 +3459,11 @@ LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
   if (input_is_qtype) {
     ShapedType int32_type = output_type.clone(rewriter.getIntegerType(32));
     mlir::quant::UniformQuantizedType input_qtype =
-        input_type.getElementType()
-            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+        mlir::dyn_cast_or_null<mlir::quant::UniformQuantizedType>(
+            input_type.getElementType());
     mlir::quant::UniformQuantizedType output_qtype =
-        output_type.getElementType()
-            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+        mlir::dyn_cast_or_null<mlir::quant::UniformQuantizedType>(
+            output_type.getElementType());

     auto sigmoid_func = [](double x) -> double {
       return 1.0 / (1.0 + std::exp(-x));
@@ -3511,9 +3517,9 @@ LogicalResult ConvertTFLTanhOp::matchAndRewrite(
   if (!input_type || !output_type) return failure();

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     return rewriter.notifyMatchFailure(
@@ -3524,11 +3530,11 @@ LogicalResult ConvertTFLTanhOp::matchAndRewrite(
   if (input_is_qtype) {
     ShapedType int32_type = output_type.clone(rewriter.getIntegerType(32));
     mlir::quant::UniformQuantizedType input_qtype =
-        input_type.getElementType()
-            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+        mlir::dyn_cast_or_null<mlir::quant::UniformQuantizedType>(
+            input_type.getElementType());
     mlir::quant::UniformQuantizedType output_qtype =
-        output_type.getElementType()
-            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
+        mlir::dyn_cast_or_null<mlir::quant::UniformQuantizedType>(
+            output_type.getElementType());

     auto tanh_func = [](double x) -> double {
       x = std::exp(-2.0 * x);
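The quantized logistic/tanh paths above never evaluate sigmoid_func/tanh_func at runtime; the lambdas are sampled offline into a TOSA 16-bit lookup table. A hedged stand-in for that sampling step, assuming the usual 513-entry TOSA TABLE layout over the int16 input range; the real helper lives in legalize_utils and also prepares interpolation data:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <vector>

// Sample f over the dequantized int16 input range into 513 table entries,
// each a Q0.15 value clamped to the int16 range. Sketch only.
std::vector<int16_t> sampleTable(const std::function<double(double)>& f,
                                 double input_scale, int32_t input_zp) {
  std::vector<int16_t> table;
  table.reserve(513);
  for (int32_t i = -256; i <= 256; ++i) {
    double x = (i * 128 - input_zp) * input_scale;  // dequantize sample point
    double y = std::clamp(f(x) * 32768.0, -32768.0, 32767.0);
    table.push_back(static_cast<int16_t>(std::lround(y)));
  }
  return table;
}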
@@ -3644,9 +3650,9 @@ static LogicalResult LegalizeQuantizedPrelu(Operation* op,
   // Perform an element-wise multiplication on rescaled alpha and input for
   // PReLU.
   Value alpha = tfl_prelu_op.getAlpha();
-  ShapedType alpha_type = alpha.getType().cast<ShapedType>();
+  ShapedType alpha_type = mlir::cast<ShapedType>(alpha.getType());
   UniformQuantizedType alpha_qtype =
-      alpha_type.getElementType().cast<UniformQuantizedType>();
+      mlir::cast<UniformQuantizedType>(alpha_type.getElementType());

   Value op_rescale_alpha = removeZeroPointAndCastToInt32(
       rewriter, op, alpha, alpha_qtype.getZeroPoint());
@@ -3698,7 +3704,7 @@ static LogicalResult LegalizeQuantizedLeakyRelu(Operation* op,
                                                 PatternRewriter& rewriter,
                                                 Value input, double alpha,
                                                 ShapedType output_type) {
-  ShapedType input_type = input.getType().cast<ShapedType>();
+  ShapedType input_type = mlir::cast<ShapedType>(input.getType());
   ShapedType rescale_type = input_type.clone(rewriter.getI32Type());

   UniformQuantizedType input_qtype =
@@ -3784,9 +3790,9 @@ LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite(
                                        "input or output is not a ShapedType");

   bool input_is_qtype =
-      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
-  bool output_is_qtype =
-      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+      mlir::isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
+  bool output_is_qtype = mlir::isa<mlir::quant::UniformQuantizedType>(
+      output_type.getElementType());

   if (input_is_qtype != output_is_qtype) {
     return rewriter.notifyMatchFailure(
@@ -3846,8 +3852,7 @@ LogicalResult ConvertTFLCustomOp::matchAndRewrite(
   rewriter.replaceOpWithNewOp<tosa::CustomOp>(
       op, op->getResultTypes(), tfl_custom_op.getCustomCode(),
       rewriter.getStringAttr("TFL"),
-      tfl_custom_op.getCustomOption()
-          .cast<mlir::TFL::ConstBytesAttr>()
+      mlir::cast<mlir::TFL::ConstBytesAttr>(tfl_custom_op.getCustomOption())
           .getValue()
           .str(),
       op->getOperands());
@@ -3966,7 +3971,7 @@ LogicalResult ConvertTFLDequantizeOp::matchAndRewrite(
   if (!qtype) return failure();

   Type element_type = qtype.getElementType();
-  if (element_type.isa<FloatType>()) {
+  if (mlir::isa<FloatType>(element_type)) {
     CreateReplaceOpAndInfer<tosa::CastOp>(rewriter, op, output_type,
                                           tfl_dequantize_op.getInput());
     return success();
@@ -4023,7 +4028,7 @@ LogicalResult ConvertTFLConstOp::matchAndRewrite(
   ElementsAttr elements = tfl_const_op.getValue();

   Type element_type = elements.getShapedType().getElementType();
-  if (output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
+  if (mlir::isa<mlir::quant::QuantizedType>(output_type.getElementType())) {
     output_type = RankedTensorType::get(output_type.getShape(), element_type);
   }

@@ -4031,7 +4036,8 @@ LogicalResult ConvertTFLConstOp::matchAndRewrite(
   // attribute shape. This occurs as some TFLite folders create constants with
   // unranked shapes.
   if (!output_type.hasRank()) {
-    output_type = elements.getType().cast<ShapedType>().clone(element_type);
+    output_type =
+        mlir::cast<ShapedType>(elements.getType()).clone(element_type);
   }

   rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, elements);
@@ -4053,8 +4059,8 @@ LogicalResult ConvertTFLQConstOp::matchAndRewrite(
   // attribute shape. This occurs as some TFLite folders create constants with
   // unranked shapes.
   if (!output_type.hasRank()) {
-    output_type = elements.getType().cast<ShapedType>().clone(
-        output_type.getElementType());
+    output_type = mlir::cast<ShapedType>(elements.getType())
+                      .clone(output_type.getElementType());
   }

   rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, elements);
@@ -4079,7 +4085,7 @@ LogicalResult ConvertConstantOp::matchAndRewrite(
   // For data type like 64 bits, we need to truncate them into 48 bits.
   if (e_type.isInteger(64)) {
     e_type = rewriter.getIntegerType(48);
-    attr = attr.cast<DenseIntElementsAttr>().mapValues(
+    attr = mlir::cast<DenseIntElementsAttr>(attr).mapValues(
         e_type, [](const APInt& x) -> APInt { return x.trunc(48); });
   }

@@ -4136,11 +4142,11 @@ LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite(
   auto indices = tfl_sparse_to_dense_op.getSparseIndices();
   auto values = tfl_sparse_to_dense_op.getSparseValues();
   auto default_value = tfl_sparse_to_dense_op.getDefaultValue();
-  auto indices_ty = indices.getType().cast<ShapedType>();
+  auto indices_ty = mlir::cast<ShapedType>(indices.getType());
   auto indices_ety = indices_ty.getElementType();
-  auto values_ty = values.getType().cast<ShapedType>();
+  auto values_ty = mlir::cast<ShapedType>(values.getType());
   auto result_ty =
-      tfl_sparse_to_dense_op.getResult().getType().cast<ShapedType>();
+      mlir::cast<ShapedType>(tfl_sparse_to_dense_op.getResult().getType());
   auto result_ety = result_ty.getElementType();
   auto loc = op->getLoc();

@@ -4262,7 +4268,7 @@ LogicalResult ConvertTFLArgMinOp::matchAndRewrite(
   auto arg_max_op = cast<TFL::ArgMinOp>(op);
   auto loc = arg_max_op.getLoc();
   auto input = arg_max_op.getInput();
-  auto input_ty = input.getType().cast<ShapedType>();
+  auto input_ty = mlir::cast<ShapedType>(input.getType());
   Type input_ety = input_ty.getElementType();

   if (auto quantized_ty = dyn_cast<mlir::quant::UniformQuantizedType>(input_ety)) {
@@ -4281,9 +4287,9 @@ LogicalResult ConvertTFLArgMinOp::matchAndRewrite(
   int32_t dim = dim_elems.getValues<APInt>()[0].getSExtValue();
   if (dim < 0) dim += input_ty.getRank();

-  if (input_ety.isa<FloatType>()) {
+  if (mlir::isa<FloatType>(input_ety)) {
     input = CreateOpAndInfer<tosa::NegateOp>(rewriter, loc, input_ty, input);
-  } else if (input_ety.isa<IntegerType>()) {
+  } else if (mlir::isa<IntegerType>(input_ety)) {
     auto reverse_ty = RankedTensorType::get({}, input_ety);
     Value reverse_val = rewriter.create<tosa::ConstOp>(
         loc, reverse_ty,
@@ -4370,12 +4376,12 @@ LogicalResult ConvertTFLRealOp::matchAndRewrite(
   Type input_ety = input_ty.getElementType();

   // For non-complex inputs, return the original tensor.
-  if (!input_ety.isa<ComplexType>()) {
+  if (!mlir::isa<ComplexType>(input_ety)) {
     CreateReplaceOpAndInfer<tosa::IdentityOp>(rewriter, op, input_ty, input);
     return success();
   }

-  if (!input_ety.cast<ComplexType>().getElementType().isF32()) {
+  if (!mlir::cast<ComplexType>(input_ety).getElementType().isF32()) {
     return rewriter.notifyMatchFailure(
         op, "complex input must be of type complex64");
   }
@@ -4425,13 +4431,13 @@ LogicalResult ConvertTFLImagOp::matchAndRewrite(
   Type input_ety = input_ty.getElementType();

   // For non-complex inputs return all zero's.
-  if (!input_ety.isa<ComplexType>()) {
+  if (!mlir::isa<ComplexType>(input_ety)) {
     CreateReplaceOpAndInfer<tosa::ConstOp>(
         rewriter, op, input_ty, DenseElementsAttr::get(input_ty, {0.0f}));
     return success();
   }

-  if (!input_ety.cast<ComplexType>().getElementType().isF32()) {
+  if (!mlir::cast<ComplexType>(input_ety).getElementType().isF32()) {
     return rewriter.notifyMatchFailure(
         op, "complex input must be of type complex64");
   }
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
index de8e777a7d558e..8571995d719484 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h"
@@ -78,7 +79,7 @@ std::optional<Value> buildReshapeWithDynamicDims(PatternRewriter& rewriter,
                                                  Value input_value,
                                                  ShapedType output_type,
                                                  llvm::ArrayRef<Value> dims) {
-  auto e_ty = input_value.getType().cast<ShapedType>().getElementType();
+  auto e_ty = mlir::cast<ShapedType>(input_value.getType()).getElementType();
   llvm::SmallVector<int64_t> static_dims;

   if (output_type.hasRank()) {
@@ -92,7 +93,7 @@ std::optional<Value> buildReshapeWithDynamicDims(PatternRewriter& rewriter,
     auto dim = dims[i];
     SplatElementsAttr dim_attr;
     if (matchPattern(dim, m_Constant(&dim_attr))) {
-      if (dim_attr.getType().cast<ShapedType>().getRank() != 0) {
+      if (mlir::cast<ShapedType>(dim_attr.getType()).getRank() != 0) {
         (void)rewriter.notifyMatchFailure(
             op, "dim for building tosa::ReshapeOp should be rank-0");
         return std::nullopt;
@@ -643,8 +644,8 @@ DenseI64ArrayAttr getPaddingValuesFromExplicitPadAttr(
   for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
     int64_t dim = GetTensorSpatialDimIndex(4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
-    pad_before = explicit_pad[dim * 2].template cast<IntegerAttr>().getInt();
-    pad_after = explicit_pad[dim * 2 + 1].template cast<IntegerAttr>().getInt();
+    pad_before = mlir::cast<IntegerAttr>(explicit_pad[dim * 2]).getInt();
+    pad_after = mlir::cast<IntegerAttr>(explicit_pad[dim * 2 + 1]).getInt();
     computed_paddings.push_back(pad_before);
     computed_paddings.push_back(pad_after);
   }
@@ -801,11 +802,11 @@ LogicalResult ApplyPatternsWithShapeResolution(
   // This should be investigate for whether it is still necessary due to quant
   // type stripping changing.
   func.walk([&](tosa::ConstOp op) {
-    if (op.getType().getElementType().isa<mlir::quant::QuantizedType>()) {
+    if (mlir::isa<mlir::quant::QuantizedType>(op.getType().getElementType())) {
       return;
     }
     auto ety = op.getValue().getShapedType().getElementType();
-    auto new_ty = op.getType().cast<ShapedType>().clone(ety);
+    auto new_ty = mlir::cast<ShapedType>(op.getType()).clone(ety);
     op.getResult().setType(new_ty);
   });
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
index d2e04ac869ae48..acb9dff2a4a8ff 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h
@@ -202,7 +202,7 @@ TosaOp CreateOpAndInfer(ImplicitLocOpBuilder& builder, Type result_ty,

   // Compute the knowledge based on the inferred type.
   auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
-  inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
+  inferredKnowledge.dtype = mlir::cast<ShapedType>(result_ty).getElementType();
   inferredKnowledge.hasRank = predictedShape.hasRank();
   if (predictedShape.hasRank()) {
     for (auto dim : predictedShape.getDims()) {
diff --git a/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc b/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc
index 987ac5deb7479f..765cf33aa08812 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc
@@ -42,6 +42,7 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
@@ -115,7 +116,7 @@ class GenericTypeConvert : public ConversionPattern {

 static bool isIllegalType(Type type) {
   if (auto shapedType = dyn_cast<ShapedType>(type)) {
-    return shapedType.getElementType().isa<ComplexType>();
+    return mlir::isa<ComplexType>(shapedType.getElementType());
   }
   return false;
 }
diff --git a/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc b/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc
index 85df18855769fc..11857f3b1c3404 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc
@@ -38,6 +38,7 @@ limitations under the License.
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
@@ -121,7 +122,7 @@ class GenericTypeConvert : public ConversionPattern {
 };

 static bool isIllegalType(Type type) {
-  if (type.isa<quant::QuantizedType>()) return true;
+  if (mlir::isa<quant::QuantizedType>(type)) return true;
   if (auto shapedType = dyn_cast<ShapedType>(type)) {
     return isIllegalType(shapedType.getElementType());
   }
diff --git a/tensorflow/compiler/mlir/utils/BUILD b/tensorflow/compiler/mlir/utils/BUILD
index e34e9cf7be7cca..2256c421b45717 100644
--- a/tensorflow/compiler/mlir/utils/BUILD
+++ b/tensorflow/compiler/mlir/utils/BUILD
@@ -16,6 +16,7 @@ cc_library(
     deps = [
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
     ],
 )
diff --git a/tensorflow/compiler/mlir/utils/name_utils.cc b/tensorflow/compiler/mlir/utils/name_utils.cc
index 6ca366fc9d64d5..7ce1c46861c2bb 100644
--- a/tensorflow/compiler/mlir/utils/name_utils.cc
+++ b/tensorflow/compiler/mlir/utils/name_utils.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project

 namespace mlir {

@@ -63,7 +64,7 @@ std::string GetNameFromLoc(Location loc) {
   while (!locs.empty()) {
     Location curr_loc = locs.pop_back_val();

-    if (auto name_loc = curr_loc.dyn_cast<NameLoc>()) {
+    if (auto name_loc = mlir::dyn_cast<NameLoc>(curr_loc)) {
       // Add name in NameLoc. For NameLoc we also account for names due to ops
       // in functions where the op's name is first.
       auto name = name_loc.getName().strref().split('@').first;
@@ -73,11 +74,11 @@ std::string GetNameFromLoc(Location loc) {
         if (!name.empty()) names_is_nonempty = true;
       }
       continue;
-    } else if (auto call_loc = curr_loc.dyn_cast<CallSiteLoc>()) {
+    } else if (auto call_loc = mlir::dyn_cast<CallSiteLoc>(curr_loc)) {
       // Use location of the Callee to generate the name.
       locs.push_back(call_loc.getCallee());
       continue;
-    } else if (auto fused_loc = curr_loc.dyn_cast<FusedLoc>()) {
+    } else if (auto fused_loc = mlir::dyn_cast<FusedLoc>(curr_loc)) {
       // Push all locations in FusedLoc in reverse order, so locations are
       // visited based on order in FusedLoc.
       auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py
index 94589f47c05d1b..60a7949138e9c1 100644
--- a/tensorflow/compiler/tests/gather_nd_op_test.py
+++ b/tensorflow/compiler/tests/gather_nd_op_test.py
@@ -17,7 +17,6 @@
 import numpy as np

 from tensorflow.compiler.tests import xla_test
-from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
@@ -43,7 +42,7 @@ def testSimpleDtype(self):
             np.array([[4], [4], [0]], np.int32)))

   @test_util.disable_mlir_bridge("Error handling")
-  def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
+  def testEmptyIndicesAndParamsAndEmptyParamsOk(self):
     with self.session():
       params = np.ones((3, 3), dtype=np.float32)

@@ -60,11 +59,11 @@ def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
       gather_nd_ok_val = self._runGather(params_empty, indices_empty)
       self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)

+      # A zero-sized gather dimension in params now produces a constant of 0.
       params_empty = np.empty((0, 3), dtype=np.float32)
       indices_nonempty = np.zeros((1, 2), dtype=np.int32)
-      with self.assertRaisesWithPredicateMatch(
-          errors.InvalidArgumentError, r"Gather dimension 0 is of size zero"):
-        self._runGather(params_empty, indices_nonempty)
+      gather_nd_ok_val = self._runGather(params_empty, indices_nonempty)
+      self.assertAllEqual(gather_nd_ok_val, np.zeros([3]))

   def testIndexScalar(self):
     params = np.array(
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 46f192648ecaa6..32ec425c02790a 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -1306,7 +1306,53 @@ def testRngBitGenerator(self, algorithm, dtype):
     with self.assertRaisesRegex(
         TypeError, 'Failed to convert elements .* to Tensor'
     ):
-      res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+      _ = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+
+  def testGatherShapeInference(self):
+    operand = np.arange(10, dtype=np.int32).reshape([2, 5])
+    start_indices = np.array([2], np.int32)
+    slice_sizes = np.array([1, 3], np.int32)
+    dimension_numbers = xla_data_pb2.GatherDimensionNumbers(
+        offset_dims=[1],
+        collapsed_slice_dims=[0],
+        start_index_map=[0],
+        index_vector_dim=1,
+    )
+
+    res = xla.gather(operand, start_indices, dimension_numbers, slice_sizes)
+    self.assertEqual(res.shape, tensor_shape.TensorShape([1, 3]))
+
+  def testGatherShapeInferenceDynamicSlice(self):
+    operand = np.arange(12, dtype=np.int32).reshape([3, 2, 2])
+    start_indices = array_ops.placeholder(np.int32, shape=(3, None, 2))
+    slice_sizes = np.array([1, 2, 2], np.int32)
+    dimension_numbers = xla_data_pb2.GatherDimensionNumbers(
+        offset_dims=[2, 3],
+        collapsed_slice_dims=[0],
+        start_index_map=[0, 1],
+        index_vector_dim=2,
+    )
+
+    res = xla.gather(operand, start_indices, dimension_numbers, slice_sizes)
+    self.assertEqual(res.shape, tensor_shape.TensorShape([3, None, 2, 2]))
+
+  def testGatherShapeInferenceDynamicInput(self):
+    operand = array_ops.placeholder(np.int32, shape=(None, 5))
+    start_indices = np.array([2], np.int32)
+    slice_sizes = np.array([1, 3], np.int32)
+    dimension_numbers = xla_data_pb2.GatherDimensionNumbers()
+
+    res = xla.gather(operand, start_indices, dimension_numbers, slice_sizes)
+    self.assertEqual(res.shape, tensor_shape.unknown_shape())
+
+  def testGatherShapeInferenceUnknownSliceSizes(self):
+    operand = np.arange(10, dtype=np.int32).reshape([2, 5])
+    start_indices = np.array([2], np.int32)
+    slice_sizes = array_ops.placeholder(np.int32, shape=(2,))
+    dimension_numbers = xla_data_pb2.GatherDimensionNumbers()
+
+    res = xla.gather(operand, start_indices, dimension_numbers, slice_sizes)
+    self.assertEqual(res.shape, tensor_shape.unknown_shape())

 if __name__ == '__main__':
diff --git a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc
index 83d5f9b59656ed..cc02348fc86726 100644
--- a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc
+++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h"
-
 #include

 #if GOOGLE_CUDA && GOOGLE_TENSORRT
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 01e85cc7c6cfc7..c1a2243de20146 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -54,6 +54,7 @@ package_group(
         "//third_party/mlperf/submissions/training/v0_7/models/...",
         "//third_party/py/keras_cv/...",
         "//third_party/py/tf_keras/...",
+        "//third_party/sparse_conv/ops/...",
         "//waymo/ml/deploy/benchmark/...",
     ],
 )
@@ -218,7 +219,6 @@ filegroup(
         "@local_tsl//tsl/framework/fixedpoint:xla_cpu_runtime_hdrs",
         "@local_tsl//tsl/platform:xla_cpu_runtime_srcs",
         "@local_xla//xla:cpu_runtime_hdrs",
-        "@local_xla//xla/runtime:aot_ffi_execution_context_hdrs",
         "@local_xla//xla/service:custom_call_status_hdrs",
         "@local_xla//xla/service/cpu:runtime_hdrs",
     ],
 )
@@ -391,7 +391,6 @@ cc_library(
         # binary produced by tfcompile.
         "@local_xla//xla:cpu_function_runtime",
         "@local_xla//xla:executable_run_options",
-        "@local_xla//xla/runtime:aot_ffi_execution_context",
         "@local_xla//xla/service/cpu:buffer_desc",
         "//tensorflow/core/platform:types",
     ],
@@ -500,6 +499,7 @@ cc_library(
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
        "//tensorflow/compiler/jit:xla_compile_util",
+        "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
         "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy",
         "//tensorflow/compiler/mlir/tf2xla/api/v1:compile_mlir_util_no_tf_dialect_passes",
         "//tensorflow/compiler/mlir/utils:array_container_utils",
@@ -706,7 +706,7 @@ cc_library(
         "@local_xla//xla/hlo/ir:hlo",
         "@local_xla//xla/service:computation_placer_hdr",
         "@local_xla//xla/service/gpu:gpu_executable_run_options",
-        "@local_xla//xla/service/gpu:nccl_clique_key",
+        "@local_xla//xla/service/gpu/runtime:nccl_clique_key",
         "@local_xla//xla/stream_executor",
         "@local_xla//xla/translate/mhlo_to_hlo:layout_util",
     ],
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 6a60149d7cc4a1..31e227970083e6 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -297,6 +297,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/log",
         "@local_xla//xla:literal_util",
         "@local_xla//xla:shape_util",
         "@local_xla//xla:status_macros",
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 5877aea0269643..8087b271ba5fe2 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -91,9 +91,14 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,

   for (int64_t i = 0; i < num_index_dims; ++i) {
     if (input_shape.dim_size(axis + i) == 0) {
-      return errors::InvalidArgument("Gather dimension ", axis + i,
-                                     " is of size zero in tensor with shape ",
-                                     input_shape.DebugString());
+      // A gather dimension of size zero in the input tensor results in a
+      // constant 0. This is done to match the legacy behavior of the MLIR
+      // legalization and avoid breaking existing models.
+      auto slice_sizes = input_shape.dim_sizes();
+      slice_sizes.erase(slice_sizes.begin() + axis);
+      *gather_output =
+          xla::Broadcast(XlaHelpers::Zero(builder, dtype), slice_sizes);
+      return absl::OkStatus();
     }
   }
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
index d7a1b5f970561a..c5229072b56429 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include
 
+#include "absl/log/log.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "xla/client/xla_builder.h"
 #include "xla/literal_util.h"
@@ -492,6 +493,27 @@ Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index,
   start_indices[0] = index;
 
   xla::XlaOp list_part = xla::GetTupleElement(list, 0);
+  {
+    TF_ASSIGN_OR_RETURN(const xla::Shape* list_part_shape,
+                        b->GetShapePtr(list_part));
+    TF_ASSIGN_OR_RETURN(const xla::Shape* update_shape, b->GetShapePtr(update));
+    for (int i = 0; i < list_part_shape->dimensions_size(); ++i) {
+      auto list_part_dim_size = list_part_shape->dimensions(i);
+      auto update_dim_size = update_shape->dimensions(i);
+      // If the update is larger than the list part, the DynamicUpdateSlice
+      // would fail, so skip the update and return the list unchanged.
+      if (update_dim_size > list_part_dim_size) {
+        LOG_FIRST_N(WARNING, 1)
+            << "Warning: TensorListSetItem: ignoring set item because the "
+               "update dim ["
+            << update_dim_size << "] is larger than the list dim ["
+            << list_part_dim_size << "] at dimension " << i << ".";
+
+        *result = list;
+        return absl::OkStatus();
+      }
+    }
+  }
   xla::XlaOp updated_list_part =
       xla::DynamicUpdateSlice(list_part, update, start_indices);
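The new guard reduces to an element-wise dimension comparison; a standalone sketch under the assumption that both shapes are static and of equal rank, as in the guarded code path:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for the new check in ExecuteTensorListSetItem: DynamicUpdateSlice
// requires every update dimension to fit inside the operand, so an oversized
// update now becomes a warning plus no-op instead of a hard failure when the
// XLA computation is built.
bool UpdateFitsInList(const std::vector<int64_t>& list_dims,
                      const std::vector<int64_t>& update_dims) {
  for (std::size_t i = 0; i < list_dims.size(); ++i) {
    if (update_dims[i] > list_dims[i]) return false;  // would fail the DUS
  }
  return true;
}
```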
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc
index 5c19b9fe1014d3..39d4b086788ffe 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc
@@ -116,8 +116,8 @@ constexpr llvm::StringRef kCustomCallShimTarget =
 }  // namespace
 
 bool IsTokenType(mlir::Type type) {
-  return type.isa<mlir::stablehlo::TokenType>() ||
-      type.isa<mlir::mhlo::TokenType>();
+  return mlir::isa<mlir::stablehlo::TokenType>(type) ||
+         mlir::isa<mlir::mhlo::TokenType>(type);
 }
 
 absl::StatusOr<std::unique_ptr<XlaCallModuleLoader>>
@@ -174,7 +174,7 @@ absl::Status XlaCallModuleLoader::SetPlatformIndex(
   op_builder.setInsertionPointToStart(&main_body);
   mlir::BlockArgument platform_index_arg = main_body.getArgument(0);
   mlir::RankedTensorType arg_ranked_type =
-      platform_index_arg.getType().dyn_cast<mlir::RankedTensorType>();
+      mlir::dyn_cast<mlir::RankedTensorType>(platform_index_arg.getType());
   if (!arg_ranked_type || arg_ranked_type.getRank() != 0 ||
       !(arg_ranked_type.getElementType().isSignlessInteger(32) ||
         arg_ranked_type.getElementType().isSignlessInteger(64))) {
@@ -301,7 +301,7 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes(
         << mlir::debugString(type) << " for argument type "
         << mlir::debugString(arg_type);
     mlir::TensorType arg_type =
-        main_body.getArgument(i).getType().dyn_cast<mlir::TensorType>();
+        mlir::dyn_cast<mlir::TensorType>(main_body.getArgument(i).getType());
     if (arg_type == nullptr) {
       return absl::InvalidArgumentError(absl::StrCat(
           "Argument ", i, " passed to XlaCallModule is not a tensor, ",
@@ -316,7 +316,8 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes(
           mlir::debugString(arg_type), ", got ", mlir::debugString(type)));
     }
 
-    if (auto ranked_arg_type = arg_type.dyn_cast<mlir::RankedTensorType>()) {
+    if (auto ranked_arg_type =
+            mlir::dyn_cast<mlir::RankedTensorType>(arg_type)) {
       if (mlir::failed(mlir::verifyCompatibleShape(ranked_arg_type.getShape(),
                                                    type.getShape()))) {
         return absl::InvalidArgumentError(absl::StrCat(
@@ -380,9 +381,10 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes(
     if (IsTokenType(arg_type) || is_input_refined) {
       continue;
     }
-    auto ranked_arg_type = arg_type.dyn_cast<mlir::RankedTensorType>();
+    auto ranked_arg_type = mlir::dyn_cast<mlir::RankedTensorType>(arg_type);
     if (!ranked_arg_type || !ranked_arg_type.hasStaticShape()) {
-      auto type = static_array_input_types[i].cast<mlir::RankedTensorType>();
+      auto type =
+          mlir::cast<mlir::RankedTensorType>(static_array_input_types[i]);
       auto custom_call = MakeShapeRefinementOperandWrapper(op_builder, arg,
                                                            type.getShape());
       auto call_result = custom_call.getResult(0);
@@ -409,8 +411,8 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes(
   // Clean up custom_call shims.
   for (auto call : llvm::make_early_inc_range(
            main_body.getOps<mlir::stablehlo::CustomCallOp>())) {
-    if (call->getAttr("call_target_name").cast<mlir::StringAttr>().strref() ==
-        kCustomCallShimTarget) {
+    if (mlir::cast<mlir::StringAttr>(call->getAttr("call_target_name"))
+            .strref() == kCustomCallShimTarget) {
       auto operand = call->getOperand(0);
       auto result = call->getResult(0);
       if (operand.getType() != result.getType()) {
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
index dbba8e7d8afd6a..529f27a0f7b25d 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
@@ -40,6 +40,7 @@ limitations under the License.
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h"
@@ -403,7 +404,7 @@ class XlaCallModuleOp : public XlaOpKernel {
     mlir::TypeRange input_types(custom_call->getOperandTypes());
     if (custom_call_has_token_input_output) {
       if (input_types.empty() ||
-          !input_types.front().isa<mlir::stablehlo::TokenType>()) {
+          !mlir::isa<mlir::stablehlo::TokenType>(input_types.front())) {
         return absl::InvalidArgumentError(absl::StrCat(
             "stablehlo.custom_call with has_token_input_output = true is "
             "expected to take !stablehlo.token as the first argument, but "
@@ -422,7 +423,7 @@ class XlaCallModuleOp : public XlaOpKernel {
     mlir::TypeRange result_types(custom_call->getResultTypes());
     if (custom_call_has_token_input_output) {
       if (result_types.empty() ||
-          !result_types.front().isa<mlir::stablehlo::TokenType>()) {
+          !mlir::isa<mlir::stablehlo::TokenType>(result_types.front())) {
         return absl::InvalidArgumentError(absl::StrCat(
             "stablehlo.custom_call with has_token_input_output = true is "
             "expected to return !stablehlo.token as the first result, but "
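The xla_call_module edits above are mechanical migrations from the deprecated member-style casts (`value.isa<T>()`, `value.dyn_cast<T>()`, `value.cast<T>()`) to the free functions that recent MLIR re-exports through mlir/Support/LLVM.h; the new include in xla_call_module_op.cc pulls in those declarations. The pattern in isolation (function names here are illustrative, not from this change):

```cpp
#include "mlir/IR/BuiltinTypes.h"  // mlir::RankedTensorType
#include "mlir/Support/LLVM.h"     // mlir::isa / mlir::dyn_cast / mlir::cast

bool IsRankedTensor(mlir::Type t) {
  // Before: t.isa<mlir::RankedTensorType>()
  return mlir::isa<mlir::RankedTensorType>(t);
}

mlir::RankedTensorType AsRankedTensor(mlir::Type t) {
  // Before: t.dyn_cast<mlir::RankedTensorType>() -- still null on mismatch.
  return mlir::dyn_cast<mlir::RankedTensorType>(t);
}
```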
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
index d1a2a68d045bfc..9994f3ae2e5c56 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
@@ -208,9 +208,21 @@ MlirOptimizationPassState MlirBridgePass::GetPassState(
     return MlirOptimizationPassState::Disabled;
   }
 
+  // TODO(b/328084279): When MlirBridgePass::GetPassState() returns
+  // MlirOptimizationPassState::FallbackEnabled or
+  // MlirOptimizationPassState::Enabled, TensorFlow imports the Graph into
+  // an MLIR module, calls MlirBridgePass::Run(), and exports the module
+  // back to a Graph. This Graph->MLIR module->Graph round trip is skipped
+  // when MlirOptimizationPassState::Disabled is returned, and it does not
+  // always reproduce the same Graph. Some input graphs that have a TPU
+  // device in device_set yet no replication depend on the round trip, so
+  // call HasTPUDevice(*device_set) to ensure such graphs still take it.
+  // Note that MlirBridgePass::Run() will still reject these graphs, so
+  // they do not go through the Phase 1 Bridge.
   return GetPassStateImpl(
-      /*is_supported_by_replicated_brige*/ IsSupportedByReplicatedBridge(
-          graph, &function_library),
+      /*is_supported_by_replicated_brige*/ IsSupportedByReplicatedBridge(
+          graph, &function_library) ||
+          HasTPUDevice(*device_set),
       config_proto, graph, function_library);
 }
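The gating decision itself is a simple disjunction; a one-line sketch (names abbreviated, and the full policy lives in GetPassStateImpl):

```cpp
// Illustrative only: graphs with a TPU device now opt into the
// Graph->MLIR->Graph round trip even when the replicated bridge does not
// support them, because some unreplicated TPU graphs depend on the round
// trip itself.
bool TakesMlirRoundTrip(bool supported_by_replicated_bridge,
                        bool has_tpu_device) {
  return supported_by_replicated_bridge || has_tpu_device;
}
```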
diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD
index 6adab4c6c7f6b4..61a329aa46f1ba 100644
--- a/tensorflow/compiler/tf2xla/ops/BUILD
+++ b/tensorflow/compiler/tf2xla/ops/BUILD
@@ -18,9 +18,15 @@ cc_library(
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
+        "@local_xla//xla:shape_util",
         "@local_xla//xla:xla_data_proto_cc",
+        "@local_xla//xla/service:shape_inference",
     ],
     alwayslink = 1,
 )
@@ -47,6 +53,8 @@ tf_custom_op_library(
     ],
     deps = [
         "@com_google_absl//absl/algorithm:container",
+        "@local_xla//xla:shape_util",
         "@local_xla//xla:xla_data_proto_cc",
+        "@local_xla//xla/service:shape_inference",
     ],
 )
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 27a534296921cd..a51ad205015bad 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -13,18 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include
+#include
+#include
+#include
 #include
 
 #include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "absl/strings/match.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_split.h"
+#include "absl/strings/str_join.h"
+#include "xla/service/shape_inference.h"
+#include "xla/shape.h"
 #include "xla/xla_data.pb.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
 
 // Note: Most of the operators defined in this module are used by the jax2tf
 // converter (see go/jax2tf for details) and are used in SavedModel produced
@@ -1123,6 +1136,26 @@ REGISTER_OP("XlaReplicaId")
     })
     .Doc("Replica ID.");
 
+xla::Shape GetShape(shape_inference::ShapeHandle shape_handle,
+                    shape_inference::InferenceContext* c) {
+  if (!c->RankKnown(shape_handle)) {
+    return xla::Shape();
+  }
+  std::vector<int64_t> dims;
+  std::vector<bool> dynamic_dims;
+  for (int i = 0, rank = c->Rank(shape_handle); i < rank; ++i) {
+    bool is_dynamic = !c->ValueKnown(c->Dim(shape_handle, i));
+    dynamic_dims.push_back(is_dynamic);
+    dims.push_back(is_dynamic ? xla::Shape::kUnboundedSize
+                              : c->Value(c->Dim(shape_handle, i)));
+  }
+  return xla::Shape(
+      // Type matters only for indices. S64 is the widest possible type.
+      xla::PrimitiveType::S64, dims,
+      absl::InlinedVector(dynamic_dims.begin(), dynamic_dims.end()),
+      /*tuple_shapes=*/{});
+}
+
 REGISTER_OP("XlaGather")
     .Input("operand: T")
     .Input("start_indices: Tindices")
@@ -1132,7 +1165,63 @@ REGISTER_OP("XlaGather")
     .Input("slice_sizes: Tindices")
     .Attr("dimension_numbers: string")
     .Attr("T: {numbertype, bool}")
     .Attr("Tindices: {int32, int64}")
     .Output("output: T")
-    .SetShapeFn(shape_inference::UnknownShape)
+    .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status {
+      std::string dimension_numbers;
+      TF_RETURN_IF_ERROR(c->GetAttr("dimension_numbers", &dimension_numbers));
+      xla::GatherDimensionNumbers gather_dim_numbers;
+      if (!gather_dim_numbers.ParseFromString(dimension_numbers)) {
+        return absl::InvalidArgumentError("Failed to parse dimension_numbers.");
+      }
+      VLOG(3) << c->DebugString();
+      VLOG(3) << "dim_numbers: " << gather_dim_numbers.DebugString();
+      VLOG(3) << "Shapes: operand: " << c->DebugString(c->input(0))
+              << ", start_indices: " << c->DebugString(c->input(1))
+              << ", slice_sizes: " << c->DebugString(c->input(2));
+
+      xla::Shape input_shape = GetShape(c->input(0), c);
+      xla::Shape start_indices_shape = GetShape(c->input(1), c);
+      xla::Shape slice_sizes_shape = GetShape(c->input(2), c);
+
+      const Tensor* slice_sizes_tensor = c->input_tensor(2);
+      if (input_shape == xla::Shape() || input_shape.is_unbounded_dynamic() ||
+          start_indices_shape == xla::Shape() ||
+          slice_sizes_shape == xla::Shape()) {
+        VLOG(3) << "output will be unranked due to unknown or dynamic input "
+                   "shapes.";
+        return shape_inference::UnknownShape(c);
+      }
+      if (slice_sizes_tensor == nullptr ||
+          slice_sizes_tensor->NumElements() == -1) {
+        VLOG(3) << "output will be unranked due to non-constant slice_sizes.";
+        return shape_inference::UnknownShape(c);
+      }
+      std::vector<int64_t> slice_sizes;
+      if (slice_sizes_tensor->dtype() == DT_INT32) {
+        for (int i = 0; i < slice_sizes_tensor->NumElements(); ++i) {
+          slice_sizes.push_back(slice_sizes_tensor->flat<int32>()(i));
+        }
+      } else if (slice_sizes_tensor->dtype() == DT_INT64) {
+        for (int i = 0; i < slice_sizes_tensor->NumElements(); ++i) {
+          slice_sizes.push_back(slice_sizes_tensor->flat<int64_t>()(i));
+        }
+      }
+      VLOG(3) << "slice_sizes [val]: " << absl::StrJoin(slice_sizes, ",");
+      TF_ASSIGN_OR_RETURN(xla::Shape output_shape,
+                          xla::ShapeInference::InferGatherShape(
+                              input_shape, start_indices_shape,
+                              gather_dim_numbers, slice_sizes));
+      std::vector<shape_inference::DimensionHandle> dims;
+      for (int64_t i = 0; i < output_shape.rank(); ++i) {
+        if (output_shape.is_unbounded_dynamic_dimension(i)) {
+          dims.push_back(c->UnknownDim());
+        } else {
+          dims.push_back(c->MakeDim(output_shape.dimensions(i)));
+        }
+      }
+      c->set_output(0, c->MakeShape(dims));
+      VLOG(3) << "output: " << c->DebugString(c->output(0));
+      return absl::OkStatus();
+    })
     .Doc(R"doc(
 Wraps the XLA Gather operator documented at
   https://www.tensorflow.org/xla/operation_semantics#gather
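The new shape function bridges two conventions for "unknown": TF shape inference tracks unknown DimensionHandles, while XLA's gather inferencer expects xla::Shape::kUnboundedSize. A standalone sketch of that round trip, with illustrative integer stand-ins for both markers (not the real sentinel values):

```cpp
#include <cstdint>
#include <vector>

constexpr int64_t kUnknown = -1;    // stand-in for an unknown TF dimension
constexpr int64_t kUnbounded = -2;  // stand-in for xla::Shape::kUnboundedSize

// TF -> XLA: unknown dims become unbounded-dynamic before InferGatherShape.
std::vector<int64_t> TfToXlaDims(const std::vector<int64_t>& tf_dims) {
  std::vector<int64_t> out;
  for (int64_t d : tf_dims) out.push_back(d == kUnknown ? kUnbounded : d);
  return out;
}

// XLA -> TF: unbounded dims in the inferred output become unknown dims again.
std::vector<int64_t> XlaToTfDims(const std::vector<int64_t>& xla_dims) {
  std::vector<int64_t> out;
  for (int64_t d : xla_dims) out.push_back(d == kUnbounded ? kUnknown : d);
  return out;
}
```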
-// Note: this is an internal type, to be used exclusively in this file.
-struct MemrefHolder {
-  MemrefHolder(const XlaCompiledCpuFunction::ShapeInfo& shape_info,
-               void* data_ptr)
-      : rank(shape_info.num_dimensions), data(data_ptr), offset(0) {
-    sizes.resize(shape_info.num_dimensions);
-    strides.resize(shape_info.num_dimensions);
-    int64_t multiplier = 1;
-    for (int i = shape_info.num_dimensions - 1; i >= 0; --i) {
-      int64_t size = shape_info.dimensions[i];
-      sizes[i] = size;
-      strides[i] = multiplier;
-      multiplier *= size;
-    }
-  }
-
-  unsigned rank = 0;
-  // Note: dtype is not needed here.
-  void* data = nullptr;
-  int64_t offset = 0;
-  std::vector<int64_t> sizes;
-  std::vector<int64_t> strides;
-};
-
-}  // namespace
-
 XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
                                                AllocMode alloc_mode)
     : raw_function_(static_data.raw_function_),
-      external_run_function_(static_data.external_run_function_),
-      cpu_executable_(static_data.cpu_executable_),
       result_index_(static_data.result_index_),
       buffer_table_(new void*[static_data.num_buffers_]),
       buffer_infos_(static_data.buffer_infos_),
@@ -73,8 +41,7 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
       variable_names_(static_data.variable_names_),
       result_names_(static_data.result_names_),
       program_shape_(static_data.program_shape_),
-      hlo_profile_printer_data_(static_data.hlo_profile_printer_data_),
-      use_xla_runtime_(static_data.use_xla_runtime_) {
+      hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
   bool allocate_entry_params =
       alloc_mode == AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS;
   // Allocate arg and temp buffers.
@@ -92,94 +59,13 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
   }
 }
 
-bool XlaCompiledCpuFunction::RunXlaRuntime() {
-  size_t num_memref_args = num_args_ + num_results_;
-  std::vector<MemrefHolder> memref_args;
-  memref_args.reserve(num_memref_args);
-
-  size_t num_ptrs = 1;  // execution context.
-
-  // Append arguments.
-  for (int i = 0; i < num_args_; ++i) {
-    const ShapeInfo& shape_info = arg_shape_infos_[i];
-    memref_args.emplace_back(shape_info, buffer_table_[arg_index_table_[i]]);
-    num_ptrs += 3 + 2 * shape_info.num_dimensions;
-  }
-
-  // Append results.
-  for (int i = 0; i < num_results_; ++i) {
-    const ShapeInfo& shape_info = result_shape_infos_[i];
-    memref_args.emplace_back(shape_info, buffer_table_[result_index_table_[i]]);
-    num_ptrs += 3 + 2 * shape_info.num_dimensions;
-
-    // Point to this result from the "result" entry in the buffer table.
-    void** results = static_cast<void**>(buffer_table_[result_index_]);
-    results[i] = buffer_table_[result_index_table_[i]];
-  }
-
-  std::vector<void*> call_frame;
-  call_frame.resize(num_ptrs);
-  size_t ptr_index = 1;
-  for (const MemrefHolder& memref : memref_args) {
-    auto cast = [](const void* p) { return const_cast<void*>(p); };
-    call_frame[ptr_index + 0] = cast(&memref.data);  // memref.basePtr
-    call_frame[ptr_index + 1] = cast(&memref.data);  // memref.data
-    call_frame[ptr_index + 2] = cast(&memref.offset);
-    unsigned rank = memref.rank;
-    for (int64_t d = 0; d < rank; ++d) {
-      call_frame[ptr_index + 3 + d] = cast(&memref.sizes[d]);
-      call_frame[ptr_index + 3 + d + rank] = cast(&memref.strides[d]);
-    }
-    ptr_index += 3 + 2 * rank;
-  }
-
-  assert(num_ptrs == ptr_index);
-
-  xla::runtime::aot::ExecutionContext execution_context;
-  execution_context.custom_call_data = &run_options_;
-  xla::runtime::aot::ExecutionContext* execution_context_ptr =
-      &execution_context;
-  call_frame[0] = &execution_context_ptr;
-
-  auto xla_runtime_func =
-      reinterpret_cast<XlaRuntimeRawFunction>(raw_function_);
-  xla_runtime_func(call_frame.data());
-  if (execution_context.error) {
-    // No error support in XLA; dump error message to stderr.
-    std::cerr << "XLA AOT error: " << execution_context.error << ".\n";
-    return false;
-  }
-  return true;
-}
-
 bool XlaCompiledCpuFunction::Run() {
-  if (use_xla_runtime_) {
-    return RunXlaRuntime();
-  }
-  if (external_run_function_) {
-    std::vector<xla::cpu::BufferDesc> descriptor_table =
-        MakeXlaRuntimeDescriptorTable();
-    return external_run_function_(cpu_executable_, descriptor_table,
-                                  &run_options_);
-  }
   XlaCustomCallStatus status;
   raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
                 buffer_table_, &status, profile_counters_);
   return !xla::CustomCallStatusGetMessage(&status).has_value();
 }
 
-std::vector<xla::cpu::BufferDesc>
-XlaCompiledCpuFunction::MakeXlaRuntimeDescriptorTable() {
-  std::vector<xla::cpu::BufferDesc> descriptor_table;
-  descriptor_table.reserve(num_buffers_);
-  for (int32_t i = 0; i < num_buffers_; ++i) {
-    void* data = buffer_table_[i];
-    uint64_t size = buffer_infos_[i].size();
-    descriptor_table.emplace_back(data, size);
-  }
-  return descriptor_table;
-}
-
 XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
   xla::cpu_function_runtime::FreeContiguous(alloc_buffer_table_);
   delete[] buffer_table_;
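With the XLA Runtime branches deleted, Run() reduces to a single call through the AOT entry point. A self-contained sketch of that remaining path (the signature is simplified; run options, custom-call status, and profile counters are member state in the real class and are stubbed with nullptr here):

```cpp
#include <cstddef>
#include <cstdint>

// Simplified stand-in for XlaCompiledCpuFunction::RawFunction; the real
// signature takes xla::ExecutableRunOptions* and XlaCustomCallStatus*.
using RawFunction = void (*)(void* result, const void* run_options,
                             const void** args, void** temps, void* status,
                             int64_t* profile_counters);

bool RunAot(RawFunction fn, void** buffer_table, std::size_t result_index) {
  // run_options, status, and the profile counters are stubbed out to keep
  // the sketch self-contained; the real Run() passes member state.
  fn(buffer_table[result_index], /*run_options=*/nullptr, /*args=*/nullptr,
     buffer_table, /*status=*/nullptr, /*profile_counters=*/nullptr);
  return true;  // The real Run() checks XlaCustomCallStatus for a message.
}
```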
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index d03f06e14f5bce..db280e239f0441 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -61,15 +61,6 @@ class XlaCompiledCpuFunction {
                                const void** args, void** temps,
                                XlaCustomCallStatus*, int64_t* profile_counters);
 
-  // Signature of the XLA Runtime raw function. Used only by XLA Runtime AOT.
-  using XlaRuntimeRawFunction = void (*)(void**);
-
-  // Signature of an external run function. Used only by XLA Runtime JIT.
-  using ExternalRunFunction =
-      bool (*)(const xla::cpu::CpuExecutable* cpu_executable,
-               const std::vector<xla::cpu::BufferDesc>& descriptor_table,
-               const xla::ExecutableRunOptions* run_options);
-
   // Simple struct to describe a tensor's shape.
   // Note: this is a poor man's substitute for xla::ShapeProto, but we cannot
   // depend on protobuf's in this library.
@@ -90,9 +81,6 @@ class XlaCompiledCpuFunction {
     // The raw function to call.
     RawFunction raw_function_;
 
-    ExternalRunFunction external_run_function_ = nullptr;
-    const xla::cpu::CpuExecutable* cpu_executable_ = nullptr;
-
     // Contains information about the buffers used by the XLA computation.
     const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr;
     int32_t num_buffers_ = 0;
@@ -139,8 +127,6 @@ class XlaCompiledCpuFunction {
     // declared so we don't have access to that information here.
     int64_t profile_counters_size_ = 0;
 
-    bool use_xla_runtime_ = false;
-
     // Only XlaCompiledCpuFunction is allowed to read and write the above
    // fields.
     friend class XlaCompiledCpuFunction;
@@ -164,6 +150,8 @@ class XlaCompiledCpuFunction {
 
   XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete;
   XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete;
+  XlaCompiledCpuFunction(XlaCompiledCpuFunction&&) = default;
+  XlaCompiledCpuFunction& operator=(XlaCompiledCpuFunction&&) = default;
 
   // Sets the intra-op thread pool used to run individual ops concurrently.
   void set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
@@ -331,16 +319,6 @@ class XlaCompiledCpuFunction {
     static_data->raw_function_ = raw_function;
   }
 
-  static void set_static_data_external_run_function(
-      StaticData* static_data, ExternalRunFunction external_run_function) {
-    static_data->external_run_function_ = external_run_function;
-  }
-
-  static void set_static_data_cpu_executable(
-      StaticData* static_data, const xla::cpu::CpuExecutable* cpu_executable) {
-    static_data->cpu_executable_ = cpu_executable;
-  }
-
   static void set_static_data_buffer_infos(
       StaticData* static_data,
       const xla::cpu_function_runtime::BufferInfo* buffer_infos) {
@@ -428,19 +406,13 @@ class XlaCompiledCpuFunction {
     static_data->profile_counters_size_ = profile_counters_size;
  }
 
-  static void set_static_data_use_xla_runtime(StaticData* static_data,
-                                              bool use_xla_runtime) {
-    static_data->use_xla_runtime_ = use_xla_runtime;
-  }
+  // TODO(ezhulenev): This is a no-op now that XLA Runtime is removed, but it
+  // is still required for building some targets. Figure out why and delete!
+  static void set_static_data_use_xla_runtime(StaticData* static_data, bool) {}
 
  private:
   const RawFunction raw_function_;
 
-  // [Optional] External Run() function.
-  const ExternalRunFunction external_run_function_;
-
-  // [Maybe Optional] CpuExecutable to be passed to external_run_function_.
-  const xla::cpu::CpuExecutable* cpu_executable_;
-
   const size_t result_index_;
 
   // Array containing pointers to argument and temp buffers (slots corresponding
@@ -488,13 +460,6 @@ class XlaCompiledCpuFunction {
   const xla::ProgramShapeProto* program_shape_ = nullptr;
   const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
 
-  const bool use_xla_runtime_ = false;
-
-  // Creates a descriptor table for XLA Runtime.
-  std::vector<xla::cpu::BufferDesc> MakeXlaRuntimeDescriptorTable();
-
-  bool RunXlaRuntime();
-
   // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the
   // `set_static_data_*` static methods above.
   friend class XlaJitCompiledCpuFunction;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 8af2c21994d4c4..b684d9b9df08ef 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
 #include "tensorflow/compiler/jit/xla_compile_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
 #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h"
 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
@@ -78,11 +79,18 @@ limitations under the License.
 #include "tensorflow/core/tpu/tpu_defs.h"
 #include "tensorflow/core/util/debug_data_dumper.h"
 #include "tensorflow/core/util/dump_graph.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/tensor_float_32_utils.h"
 
 namespace tensorflow {
 namespace {
 
+// Name of component for error logging. This name is fixed and required to
+// enable logging.
+constexpr char kSingleOpComponent[] = "TF2XLA_XLA_COMPILER_COMPILE_SINGLE_OP";
+constexpr char kCompileFunctionComponent[] =
+    "TF2XLA_XLA_COMPILER_COMPILE_FUNCTION";
+
 // Checks that arguments `args` match types `types`.
 Status CheckSignature(const DataTypeVector& types,
                       absl::Span<const XlaCompiler::Argument> args) {
@@ -769,6 +777,9 @@ Status XlaCompiler::CompileSingleOp(
     tensorflow::metrics::IncrementPhase2XlaCompilerCounter(
         tensorflow::metrics::Phase2XlaCompilerMetric::
             kCompileSingleOpXlaBuilderFailure);
+    tsl::error_logging::Log(mlir::TF::kBridgeComponent, kSingleOpComponent,
+                            status.ToString())
+        .IgnoreError();
   }
   return status;
 }
@@ -778,7 +789,7 @@ Status XlaCompiler::CompileFunction(
     const NameAttrList& fn_name_attrs,
     absl::Span<const XlaCompiler::Argument> args,
     XlaCompiler::CompilationResult* result) {
-  const string function_id =
+  string function_id =
       Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr()));
   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
@@ -861,49 +872,25 @@ Status XlaCompiler::CompileFunction(
 
   VLOG(1) << "====================================================";
 
-  auto state = ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
-  if (options.is_entry_computation) {
-    state = GetMlirBridgeRolloutState(config_proto);
-  }
-
-  if (state == ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
-    GraphDebugInfo debug_info;
-    VLOG(1) << "Using the MLIR bridge to compile the function.";
-    std::vector<std::string> valid_control_rets =
-        GetValidControlRets(fbody->control_ret_nodes, *graph);
-    auto mlir_result = CompileGraphToXlaHlo(
-        std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
-        valid_control_rets, options_.device_type.type_string(),
-        options.use_tuple_arg, /*analyse_graph=*/false, *options_.flib_def,
-        debug_info, options_.shape_determination_fns, result);
-    if (mlir_result.ok()) {
-      tensorflow::metrics::IncrementPhase2XlaCompilerCounter(
-          tensorflow::metrics::Phase2XlaCompilerMetric::
-              kCompileFunctionMlirSuccess);
-      VLOG(1) << "MLIR bridge was successfull";
-    } else {
-      tensorflow::metrics::IncrementPhase2XlaCompilerCounter(
-          tensorflow::metrics::Phase2XlaCompilerMetric::
-              kCompileFunctionMlirFailure);
-      VLOG(1) << "MLIR failed, no fallback";
-      return mlir_result;
-    }
-  } else {
-    VLOG(1) << "MLIR bridge off. Using the old bridge to compile the function";
-    auto status =
-        CompileGraph(options, function_id, std::move(graph), args, result);
-    if (!status.ok()) {
-      tensorflow::metrics::IncrementPhase2XlaCompilerCounter(
-          tensorflow::metrics::Phase2XlaCompilerMetric::
-              kCompileFunctionXlaBuilderFailure);
-      ::tsl::errors::AppendToMessage(
-          &status, "tf2xla conversion failed while converting ", function_id,
-          ". Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and "
-          "--vmodule=xla_compiler=2 to obtain a dump of the compiled "
-          "functions.");
-      return status;
-    }
+  VLOG(1) << "CompileFunction with XlaBuilder";
+  auto status =
+      CompileGraph(options, function_id, std::move(graph), args, result);
+  if (!status.ok()) {
+    tensorflow::metrics::IncrementPhase2XlaCompilerCounter(
+        tensorflow::metrics::Phase2XlaCompilerMetric::
+            kCompileFunctionXlaBuilderFailure);
+    ::tsl::errors::AppendToMessage(
+        &status, "tf2xla conversion failed while converting ",
+        std::move(function_id),
+        ". Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and "
+        "--vmodule=xla_compiler=2 to obtain a dump of the compiled "
+        "functions.");
+    tsl::error_logging::Log(mlir::TF::kBridgeComponent,
+                            kCompileFunctionComponent, status.ToString())
+        .IgnoreError();
+    return status;
   }
+
   tensorflow::metrics::IncrementPhase2XlaCompilerCounter(
       tensorflow::metrics::Phase2XlaCompilerMetric::
           kCompileFunctionXlaBuilderSuccess);
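Both compile paths now follow the same failure pattern: augment the Status with debugging hints, emit it through the component-scoped error log, and return it unchanged. A rough standalone sketch of the pattern (LogToComponent and the local Status struct are stand-ins for tsl::error_logging::Log and tsl::Status, whose real uses appear in the diff above):

```cpp
#include <iostream>
#include <string>

struct Status {  // minimal stand-in for tsl::Status
  bool ok() const { return message.empty(); }
  std::string message;
};

// Stand-in for tsl::error_logging::Log(component, subcomponent, message).
void LogToComponent(const std::string& component,
                    const std::string& subcomponent, const std::string& msg) {
  std::cerr << "[" << component << "/" << subcomponent << "] " << msg << "\n";
}

Status CompileWithLogging(Status status) {
  if (!status.ok()) {
    status.message += " (run with --vmodule=xla_compiler=2 for a dump)";
    LogToComponent("TF2XLA", "XLA_COMPILER_COMPILE_FUNCTION", status.message);
  }
  return status;  // Logging never masks the original error.
}
```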
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/jit/xla_compile_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" #include "tensorflow/compiler/mlir/utils/array_container_utils.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -78,11 +79,18 @@ limitations under the License. #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/util/debug_data_dumper.h" #include "tensorflow/core/util/dump_graph.h" +#include "tsl/platform/errors.h" #include "tsl/platform/tensor_float_32_utils.h" namespace tensorflow { namespace { +// Name of component for error logging. This name is fixed and required to +// enable logging. +constexpr char kSingleOpComponent[] = "TF2XLA_XLA_COMPILER_COMPILE_SINGLE_OP"; +constexpr char kCompileFunctionComponent[] = + "TF2XLA_XLA_COMPILER_COMPILE_FUNCTION"; + // Checks that arguments `args` match types `types`. Status CheckSignature(const DataTypeVector& types, absl::Span args) { @@ -769,6 +777,9 @@ Status XlaCompiler::CompileSingleOp( tensorflow::metrics::IncrementPhase2XlaCompilerCounter( tensorflow::metrics::Phase2XlaCompilerMetric:: kCompileSingleOpXlaBuilderFailure); + tsl::error_logging::Log(mlir::TF::kBridgeComponent, kSingleOpComponent, + status.ToString()) + .IgnoreError(); } return status; } @@ -778,7 +789,7 @@ Status XlaCompiler::CompileFunction( const NameAttrList& fn_name_attrs, absl::Span args, XlaCompiler::CompilationResult* result) { - const string function_id = + string function_id = Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; @@ -861,49 +872,25 @@ Status XlaCompiler::CompileFunction( VLOG(1) << "===================================================="; - auto state = ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; - if (options.is_entry_computation) { - state = GetMlirBridgeRolloutState(config_proto); - } - - if (state == ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { - GraphDebugInfo debug_info; - VLOG(1) << "Using the MLIR bridge to compile the function."; - std::vector valid_control_rets = - GetValidControlRets(fbody->control_ret_nodes, *graph); - auto mlir_result = CompileGraphToXlaHlo( - std::move(*graph), mlir::SpanToArrayRef(args), - valid_control_rets, options_.device_type.type_string(), - options.use_tuple_arg, /*analyse_graph=*/false, *options_.flib_def, - debug_info, options_.shape_determination_fns, result); - if (mlir_result.ok()) { - tensorflow::metrics::IncrementPhase2XlaCompilerCounter( - tensorflow::metrics::Phase2XlaCompilerMetric:: - kCompileFunctionMlirSuccess); - VLOG(1) << "MLIR bridge was successfull"; - } else { - tensorflow::metrics::IncrementPhase2XlaCompilerCounter( - tensorflow::metrics::Phase2XlaCompilerMetric:: - kCompileFunctionMlirFailure); - VLOG(1) << "MLIR failed, no fallback"; - return mlir_result; - } - } else { - VLOG(1) << "MLIR bridge off. Using the old bridge to compile the function"; - auto status = - CompileGraph(options, function_id, std::move(graph), args, result); - if (!status.ok()) { - tensorflow::metrics::IncrementPhase2XlaCompilerCounter( - tensorflow::metrics::Phase2XlaCompilerMetric:: - kCompileFunctionXlaBuilderFailure); - ::tsl::errors::AppendToMessage( - &status, "tf2xla conversion failed while converting ", function_id, - ". 
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 29e0de5edafbc2..e2adb15245c183 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1266,7 +1266,7 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@eigen_archive//:eigen3",
         "@ml_dtypes//:float8",
-        "@ml_dtypes//:int4",
+        "@ml_dtypes//:intn",
     ] + if_static([":lib_internal_impl"]),
 )
@@ -1294,7 +1294,7 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@eigen_archive//:eigen3",
         "@ml_dtypes//:float8",
-        "@ml_dtypes//:int4",
+        "@ml_dtypes//:intn",
     ],
 )
@@ -1443,7 +1443,7 @@ cc_library(
         "@eigen_archive//:eigen3",
         "@local_tsl//tsl/lib/math:math_util",
         "@ml_dtypes//:float8",
-        "@ml_dtypes//:int4",
+        "@ml_dtypes//:intn",
         "@snappy",
         "@zlib",
     ] + select({
diff --git a/tensorflow/core/api_def/base_api/api_def_ConvertToListOfSparseCoreCooTensors.pbtxt b/tensorflow/core/api_def/base_api/api_def_ConvertToListOfSparseCoreCooTensors.pbtxt
new file mode 100644
index 00000000000000..13f09747d4025a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ConvertToListOfSparseCoreCooTensors.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ConvertToListOfSparseCoreCooTensors"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ConvertToSparseCoreCsrWrappedCooTensor.pbtxt b/tensorflow/core/api_def/base_api/api_def_ConvertToSparseCoreCsrWrappedCooTensor.pbtxt
new file mode 100644
index 00000000000000..8676be4f4f6c2f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ConvertToSparseCoreCsrWrappedCooTensor.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ConvertToSparseCoreCsrWrappedCooTensor"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt
index 7f2a8a1cf1ab33..da21e0f6981c7e 100644
--- a/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt
@@ -55,6 +55,10 @@ Note that on CPU, if an out of bound index is found, an error is returned.
 On GPU, if an out of bound index is found, a 0 is stored in the
 corresponding output value.
 
+Note that on TPU, if any dimension of `params` has size 0, the output is a
+tensor of the expected shape filled with zeros. On CPU and GPU an error is
+returned instead.
+
 See also `tf.batch_gather` and `tf.gather_nd`.
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_GetStatsFromListOfSparseCoreCooTensors.pbtxt b/tensorflow/core/api_def/base_api/api_def_GetStatsFromListOfSparseCoreCooTensors.pbtxt
new file mode 100644
index 00000000000000..e0976255bd776e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_GetStatsFromListOfSparseCoreCooTensors.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "GetStatsFromListOfSparseCoreCooTensors"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_GetTpuTaskId.pbtxt b/tensorflow/core/api_def/base_api/api_def_GetTpuTaskId.pbtxt
new file mode 100644
index 00000000000000..62e52142d9af92
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_GetTpuTaskId.pbtxt
@@ -0,0 +1,14 @@
+op {
+  graph_op_name: "GetTpuTaskId"
+  visibility: HIDDEN
+  out_arg {
+    name: "tpu_task_id"
+    description: <