diff --git a/.bazelrc b/.bazelrc index 5bac396c287384..657f6b86246aae 100644 --- a/.bazelrc +++ b/.bazelrc @@ -632,8 +632,12 @@ try-import %workspace%/.bazelrc.user # Build TensorFlow v2. test:release_base --test_size_filters=small,medium +# Ensure release_base is set on linux +build:release_linux_base --config=release_base + # Target the AVX instruction set build:release_linux_base --config=avx_linux + # Enable support for all targets build:release_base --config=cpu_cross @@ -719,12 +723,14 @@ build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gc build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain build:release_cpu_macos --config=avx_linux -test:release_cpu_macos --config=release_base # Base build configs for macOS build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer build:release_macos_base --define=no_nccl_support=true --output_filter=^$ +# Ensure release_base is set on mac +build:release_macos_base --config=release_base + # Build configs for macOS x86 build:release_macos_x86 --config=release_macos_base # Build with the AVX instruction set when on macOS x86 @@ -754,10 +760,12 @@ test:release_macos_x86 --config=release_macos_base # Test configs for macOS Arm64 test:release_macos_arm64 --config=release_macos_base +# Ensure release_base is set on windows +build:release_cpu_windows --config=release_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true -test:release_cpu_windows --config=release_base # Exclude TFRT integration for anything but Linux. build:android --config=no_tfrt @@ -962,3 +970,6 @@ build:rbe_cross_compile_macos_x86 --jobs=100 test:rbe_cross_compile_macos_x86 --jobs=100 # END MACOS CROSS-COMPILE CONFIGS # END CROSS-COMPILE CONFIGS + +# Try to load the XLA warnings config if available +try-import %workspace%/warnings.bazelrc diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index bdce23b94d02f1..d670cd6040401d 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -123,7 +123,7 @@ jobs: title: Update the RBE images to the latest container versions committer: TensorFlow Release Automation token: ${{ secrets.JENKINS_TOKEN }} - reviewers: angerson,mihaimaruseac,learning-to-play,nitins17 + reviewers: mihaimaruseac,learning-to-play,nitins17 body: | This PR was created by a GitHub Actions workflow to update all the SIG Build-based RBE containers to the most recent containers. See: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2b9b9f9304d142..89c61463462745 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -71,7 +71,7 @@ Before sending your pull requests, make sure you do the following: In a graphical form, the entire lifetime of a PR looks like -![image](https://user-images.githubusercontent.com/323199/229561784-0a2f5509-b731-493f-ad88-bad487688c8d.png) +![image](https://github.com/tensorflow/tensorflow/assets/52792999/3eea4ca5-daa0-4570-b0b5-2a2b03a724a3) ### Contributor License Agreements diff --git a/RELEASE.md b/RELEASE.md index 8089cf75521191..4ab21903418868 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -29,6 +29,8 @@ * GPU * Support for NVIDIA GPUs with compute capability 8.9 (e.g. L4 & L40) has been added to TF binary distributions (Python wheels). +* Replace `DebuggerOptions` of TensorFlow Quantizer, and migrate to + `DebuggerConfig` of StableHLO Quantizer. ## Keras @@ -70,6 +72,12 @@ schema globally in the converter and inference engine. The new behaviour can be disabled via experimental flag `converter._experimental_disable_per_channel_quantization_for_dense_layers = True`. + * C API: + * The experimental `TfLiteRegistrationExternal` type has been renamed as + `TfLiteOperator`, and likewise for the corresponding API functions. + * The Python TF Lite Interpreter bindings now have an option + `experimental_default_delegate_latest_features` to enable all default + delegate features. ## Thanks to our Contributors diff --git a/WORKSPACE b/WORKSPACE index a697405110e206..675a9481283514 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,3 +1,5 @@ +# buildifier: disable=load-on-top + workspace(name = "org_tensorflow") # We must initialize hermetic python first. @@ -23,7 +25,7 @@ load("@rules_python//python:repositories.bzl", "py_repositories") py_repositories() -load("@rules_python//python:repositories.bzl", "python_register_toolchains") +load("@rules_python//python:repositories.bzl", "python_register_toolchains") # buildifier: disable=same-origin-load load( "//tensorflow/tools/toolchains/python:python_repo.bzl", "python_repository", diff --git a/ci/official/envs/ci_default b/ci/official/envs/ci_default index 96d87423392541..7db6569b3dc075 100644 --- a/ci/official/envs/ci_default +++ b/ci/official/envs/ci_default @@ -64,5 +64,6 @@ TFCI_PYTHON_VERSION= TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_AUDIT_PLAT= TFCI_WHL_BAZEL_TEST_ENABLE= +TFCI_WHL_IMPORT_TEST_ENABLE=1 TFCI_WHL_SIZE_LIMIT= TFCI_WHL_SIZE_LIMIT_ENABLE= diff --git a/ci/official/envs/linux_x86_tpu b/ci/official/envs/linux_x86_tpu index 3c7d61b2ac3794..8fa88ad7c85902 100644 --- a/ci/official/envs/linux_x86_tpu +++ b/ci/official/envs/linux_x86_tpu @@ -18,5 +18,6 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS="--repo_env=WHEEL_NAME=tensorflow_tpu" TFCI_LIB_SUFFIX="-tpu-linux-x86_64" TFCI_WHL_BAZEL_TEST_ENABLE=0 +TFCI_WHL_IMPORT_TEST_ENABLE=0 TFCI_WHL_SIZE_LIMIT=580M TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-tf-releases/index.html" diff --git a/ci/official/requirements_updater/BUILD.bazel b/ci/official/requirements_updater/BUILD.bazel index 8cdb70597f0a83..06a0898d9a2b78 100644 --- a/ci/official/requirements_updater/BUILD.bazel +++ b/ci/official/requirements_updater/BUILD.bazel @@ -13,10 +13,10 @@ # limitations under the License. # ============================================================================== -load("@python//3.9:defs.bzl", compile_pip_requirements_3_9 = "compile_pip_requirements") load("@python//3.10:defs.bzl", compile_pip_requirements_3_10 = "compile_pip_requirements") load("@python//3.11:defs.bzl", compile_pip_requirements_3_11 = "compile_pip_requirements") load("@python//3.12:defs.bzl", compile_pip_requirements_3_12 = "compile_pip_requirements") +load("@python//3.9:defs.bzl", compile_pip_requirements_3_9 = "compile_pip_requirements") load("@updater_config_repository//:updater_config_repository.bzl", "REQUIREMENTS_FILE_NAME") compile_pip_requirements_3_9( diff --git a/ci/official/requirements_updater/WORKSPACE b/ci/official/requirements_updater/WORKSPACE index 9b56cc0422bf6d..f9a116a6a3153e 100644 --- a/ci/official/requirements_updater/WORKSPACE +++ b/ci/official/requirements_updater/WORKSPACE @@ -1,3 +1,5 @@ +# buildifier: disable=load-on-top + workspace(name = "requirements_updater") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") @@ -22,7 +24,7 @@ load("@rules_python//python:repositories.bzl", "py_repositories") py_repositories() -load("@rules_python//python:repositories.bzl", "python_register_multi_toolchains") +load("@rules_python//python:repositories.bzl", "python_register_multi_toolchains") # buildifier: disable=same-origin-load load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies") default_python_version = "3.10" diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh index 4f4ea6745d5cb9..a79ce2a8868a3e 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -58,8 +58,10 @@ venv=$(mktemp -d) "python${TFCI_PYTHON_VERSION}" -m venv "$venv" python="$venv/bin/python3" "$python" -m pip install *.whl $TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS -"$python" -c 'import tensorflow as tf; t1=tf.constant([1,2,3,4]); t2=tf.constant([5,6,7,8]); print(tf.add(t1,t2).shape)' -"$python" -c 'import sys; import tensorflow as tf; sys.exit(0 if "keras" in tf.keras.__name__ else 1)' +if [[ "$TFCI_WHL_IMPORT_TEST_ENABLE" == "1" ]]; then + "$python" -c 'import tensorflow as tf; t1=tf.constant([1,2,3,4]); t2=tf.constant([5,6,7,8]); print(tf.add(t1,t2).shape)' + "$python" -c 'import sys; import tensorflow as tf; sys.exit(0 if "keras" in tf.keras.__name__ else 1)' +fi # VERY basic check to ensure the [and-cuda] package variant is installable. # Checks TFCI_BAZEL_COMMON_ARGS for "gpu" or "cuda", implying that the test is # relevant. All of the GPU test machines have CUDA installed via other means, diff --git a/ci/official/wheel_test/WORKSPACE b/ci/official/wheel_test/WORKSPACE index cef9033d30120f..d52a3ed895173b 100644 --- a/ci/official/wheel_test/WORKSPACE +++ b/ci/official/wheel_test/WORKSPACE @@ -1,3 +1,5 @@ +# buildifier: disable=load-on-top + workspace(name = "wheel_test") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") @@ -38,7 +40,7 @@ python_repository(name = "python_version_repo") load("@python_version_repo//:py_version.bzl", "TF_PYTHON_VERSION") # Register multi toolchains -load("@rules_python//python:repositories.bzl", "python_register_toolchains") +load("@rules_python//python:repositories.bzl", "python_register_toolchains") # buildifier: disable=same-origin-load python_register_toolchains( name = "python", diff --git a/configure.py b/configure.py index c1cb20162012f6..66427431b42c16 100644 --- a/configure.py +++ b/configure.py @@ -759,7 +759,7 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path): android_ndk_api_level = prompt_loop_or_load_from_env( environ_cp, var_name='ANDROID_NDK_API_LEVEL', - var_default='26', # 26 is required to support AHardwareBuffer. + var_default='21', # 21 is required for ARM64 support. ask_for_var=( 'Please specify the (min) Android NDK API level to use. ' '[Available levels: %s]' @@ -807,6 +807,18 @@ def choose_compiler(environ_cp): return var +def choose_compiler_Win(environ_cp): + question = 'Do you want to use Clang to build TensorFlow?' + yes_reply = 'Add "--config=win_clang" to compile TensorFlow with CLANG.' + no_reply = 'MSVC will be used to compile TensorFlow.' + var = int( + get_var( + environ_cp, 'TF_NEED_CLANG', None, True, question, yes_reply, no_reply + ) + ) + return var + + def set_clang_compiler_path(environ_cp): """Set CLANG_COMPILER_PATH and environment variables. @@ -848,6 +860,44 @@ def set_clang_compiler_path(environ_cp): return clang_compiler_path +def set_clang_compiler_path_win(environ_cp): + """Set CLANG_COMPILER_PATH and environment variables. + + Loop over user prompts for clang path until receiving a valid response. + Default is used if no input is given. Set CLANG_COMPILER_PATH and write + environment variables CC and BAZEL_COMPILER to .bazelrc. + + Args: + environ_cp: (Dict) copy of the os.environ. + + Returns: + string value for clang_compiler_path. + """ + # Default path if clang-16 is installed by using apt-get install + default_clang_path = 'C:/Program Files/LLVM/bin/clang.exe' + if not os.path.exists(default_clang_path): + default_clang_path = which('clang') or '' + + clang_compiler_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='CLANG_COMPILER_PATH', + var_default=default_clang_path, + ask_for_var='Please specify the path to clang executable.', + check_success=os.path.exists, + resolve_symlinks=True, + error_msg=( + 'Invalid clang path. %s cannot be found. Note that Clang is now' + 'preferred compiler. You may use MSVC by removing --config=win_clang' + ), + ) + + write_action_env_to_bazelrc('CLANG_COMPILER_PATH', clang_compiler_path) + write_to_bazelrc('build --repo_env=CC=%s' % clang_compiler_path) + write_to_bazelrc('build --repo_env=BAZEL_COMPILER=%s' % clang_compiler_path) + + return clang_compiler_path + + def retrieve_clang_version(clang_executable): """Retrieve installed clang version. @@ -1386,8 +1436,9 @@ def main(): else: raise UserInputError( 'Invalid CUDA setting were provided %d ' - 'times in a row. Assuming to be a scripting mistake.' % - _DEFAULT_PROMPT_ASK_ATTEMPTS) + 'times in a row. Assuming to be a scripting mistake.' + % _DEFAULT_PROMPT_ASK_ATTEMPTS + ) set_tf_cuda_compute_capabilities(environ_cp) if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( @@ -1415,6 +1466,12 @@ def main(): clang_compiler_path = set_clang_compiler_path(environ_cp) clang_version = retrieve_clang_version(clang_compiler_path) disable_clang_offsetof_extension(clang_version) + if is_windows(): + environ_cp['TF_NEED_CLANG'] = str(choose_compiler_Win(environ_cp)) + if environ_cp.get('TF_NEED_CLANG') == '1': + clang_compiler_path = set_clang_compiler_path_win(environ_cp) + clang_version = retrieve_clang_version(clang_compiler_path) + disable_clang_offsetof_extension(clang_version) # ROCm / CUDA are mutually exclusive. # At most 1 GPU platform can be configured. diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index eff34ce6a3f320..47c751c54b15a5 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -2,6 +2,8 @@ # C API for TensorFlow, for use by client language bindings. load("@bazel_skylib//lib:selects.bzl", "selects") +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "check_deps", @@ -18,8 +20,6 @@ load( "//tensorflow/core/tpu:build_defs.bzl", "if_libtpu_tf_status", ) -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index 647937fe98e47d..54afc6f757d740 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) +load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -56,9 +56,7 @@ cc_library( name = "parallel_device", srcs = [":device_sources"], hdrs = [":device_headers"], - visibility = [ - "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. - ], + visibility = ["//visibility:private"], deps = [ ":parallel_device_lib", "//tensorflow/c:c_api", diff --git a/tensorflow/c/experimental/filesystem/BUILD b/tensorflow/c/experimental/filesystem/BUILD index e55d95334cf6cf..d25e6e9314f088 100644 --- a/tensorflow/c/experimental/filesystem/BUILD +++ b/tensorflow/c/experimental/filesystem/BUILD @@ -1,9 +1,8 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - # Experimental filesystem C APIs for TensorFlow. # Will be moved in proper place once all filesystems are converted to the # modular framework. load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index bd2041b1d43957..7c23cb79143b01 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -1,7 +1,6 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - # Experimental gcs filesystem plugin. load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD index 90acb2bf389370..a4406b46945193 100644 --- a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD @@ -1,7 +1,6 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - # Experimental posix filesystem plugin. load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -25,6 +24,7 @@ cc_library( hdrs = ["posix_filesystem.h"], deps = [ ":posix_filesystem_helper", + "//tensorflow/c:tf_file_statistics", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", ], @@ -40,6 +40,8 @@ cc_library( ":posix_filesystem_impl", "//tensorflow/c/experimental/filesystem:filesystem_interface", "//tensorflow/c/experimental/filesystem:modular_filesystem", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/log", ], alwayslink = 1, ) diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.cc index f1f3dda5e8ccc0..e3fbf03ea19440 100644 --- a/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.cc @@ -26,7 +26,9 @@ limitations under the License. #include #include +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h" +#include "tensorflow/c/tf_file_statistics.h" #include "tensorflow/c/tf_status.h" // Implementation of a filesystem for POSIX environments. diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_static.cc b/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_static.cc index 6081722e699e86..60205858499aed 100644 --- a/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_static.cc +++ b/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_static.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/log/log.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h" #include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD index 2ac57f6a731344..159e36e485e6a6 100644 --- a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD @@ -1,7 +1,6 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - # Experimental windows filesystem plugin. load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 65f580deee93c4..a3fa49fffa34b7 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -1,14 +1,14 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_libtpu", "tf_cuda_cc_test", ) +load("//tensorflow:tensorflow.default.bzl", "filegroup") load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") # Library of gradient functions. package( @@ -55,9 +55,7 @@ cc_library( hdrs = [ "nn_grad.h", ], - visibility = [ - "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. - ], + visibility = ["//visibility:private"], deps = [ "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:gradients_internal", @@ -148,9 +146,7 @@ cc_library( testonly = True, srcs = ["grad_test_helper.cc"], hdrs = ["grad_test_helper.h"], - visibility = [ - "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. - ], + visibility = ["//visibility:private"], deps = [ "//tensorflow/c/eager:gradient_checker", "//tensorflow/c/eager:gradients_internal", diff --git a/tensorflow/c/experimental/grappler/BUILD b/tensorflow/c/experimental/grappler/BUILD index fd26096fd5d871..d4892b1b9b9624 100644 --- a/tensorflow/c/experimental/grappler/BUILD +++ b/tensorflow/c/experimental/grappler/BUILD @@ -1,11 +1,11 @@ # Description: # Graph C API. -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD index 03c83a4e8f99e0..3d92b7ad3d2992 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -94,9 +94,9 @@ tf_cc_test( "@local_xla//xla:shape_util", "@local_xla//xla/pjrt:pjrt_api", "@local_xla//xla/pjrt:pjrt_c_api_client", - "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", "@local_xla//xla/pjrt/c:pjrt_c_api_cpu", "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs", "@local_xla//xla/pjrt/c:pjrt_c_api_wrapper_impl", + "@local_xla//xla/pjrt/cpu:cpu_client", ], ) diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc index c72f0cfafa6ead..7f45fd91a1baea 100644 --- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc +++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc @@ -25,9 +25,9 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_cpu.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD index 7589ea2d2f24a2..c13bc899f2d016 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD @@ -20,13 +20,9 @@ cc_library( deps = [ "//tensorflow/c/experimental/ops/gen/common", "//tensorflow/c/experimental/ops/gen/cpp/views", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:op_gen_lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:str_util", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc index 36c25c92760872..1fc16e093c011d 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc index 44f23ae0fb6aed..71132cfc3bf8b2 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc @@ -14,8 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.h" -#include "tensorflow/c/experimental/ops/gen/common/view_util.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc index 8bfd5a334c565d..7a4275b532eda7 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h" #include "tensorflow/c/experimental/ops/gen/common/case_format.h" -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h index cfe2a99acfddce..a45fe89a7a011c 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_GUARD_RENDERER_H_ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc index 5242d6f1baf255..38f31209f6da24 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc @@ -14,7 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h" -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h index b98547079f3ac7..e43715a62e45b0 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_INCLUDE_RENDERER_H_ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc index 5547ca22df7ab0..db28ab303ae5c6 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h" -#include "absl/strings/str_split.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h index a54fc5878a0ad4..fd8ccf9531ef51 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.cc index e5afb7b6d63393..5d11bcada6e8c0 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h index 1d85c4c9fd7940..9131cc945349af 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_COMMENT_RENDERER_H_ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.cc index e2184fcc7f834f..804e0585f88cca 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/experimental/ops/gen/common/view_util.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/arg_view.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h index 9237eb9410bad7..98c3b0d75524aa 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_IMPLEMENTATION_RENDERER_H_ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc index 41db2ced426b47..c58e67782dfc34 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc @@ -16,7 +16,15 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_argument_view.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h index c29fb35b5b6b7c..3360e14e672e3a 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc index 0e6ee460512d2d..41d1dea64b3689 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc @@ -14,9 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stringpiece.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h index b0a95baefa7676..b6168b196b35b2 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_H_ +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc index 2674e5f156d9d5..eff654c5938160 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc @@ -14,8 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/common/path_config.h" +#include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/pluggable_profiler/BUILD b/tensorflow/c/experimental/pluggable_profiler/BUILD index a34faa3146735b..49bb842e2e6258 100644 --- a/tensorflow/c/experimental/pluggable_profiler/BUILD +++ b/tensorflow/c/experimental/pluggable_profiler/BUILD @@ -1,8 +1,8 @@ # Description: # Profiler C API -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index af37ab0cb19011..9ad56fcf6671b2 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -1,5 +1,3 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - # Experimental SavedModel C APIs for TensorFlow. See RFC # https://github.com/tensorflow/community/pull/207 # Targets in this directory are pure C++ "Classes" underlying the C API types @@ -9,6 +7,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD index cce725db3fcba1..3e9d28ed8795d4 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/BUILD +++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD @@ -1,11 +1,10 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - # This package contains written convenience helpers for Eager Operations # used by SavedModel. Once we autogenerate C++ Eager Op wrappers, we can remove these. load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 36e5cb52d2ec25..244bbc9e515f19 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -9,8 +9,6 @@ # Note(bmzhao): The *.cc files in this directory form the direct implementation of the # C API functions exposed in tf/c/experimental/saved_model/public/. -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - # Note(bmzhao): All *type.h files in this directory are the internal definitions of # the opaque C types. These headers should only be visible to internal tensorflow # implementors. @@ -19,6 +17,7 @@ load( "tf_cc_test", "tf_copts", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -162,9 +161,7 @@ cc_library( hdrs = [ "saved_model_api_type.h", ], - visibility = [ - "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. - ], + visibility = ["//visibility:private"], deps = [ "//tensorflow/c:conversion_macros", "//tensorflow/c/experimental/saved_model/core:saved_model_api", diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD index a10cfd03e3dc86..ec36b292a6518e 100644 --- a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow:strict.default.bzl", "py_strict_binary") +load("//tensorflow:tensorflow.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index 8e38201c2a5960..7bcaa66060665b 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_gen_op_libs", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_gen_op_libs", "tf_kernel_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a989862e4f79fb..43f63b2ba0cb81 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -2,7 +2,6 @@ # TensorFlow is a computational framework, primarily for use in machine # learning applications. -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "cc_library_with_android_deps", @@ -11,6 +10,7 @@ load( "transitive_hdrs", ) load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_gen_op_wrappers_cc") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD index e749d2433bd696..70184355fe76aa 100644 --- a/tensorflow/cc/experimental/base/tests/BUILD +++ b/tensorflow/cc/experimental/base/tests/BUILD @@ -1,7 +1,6 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - # Tests for the C++ header-only base types. load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/cc/experimental/libexport/BUILD b/tensorflow/cc/experimental/libexport/BUILD index 910ab930440f68..d206c115abea65 100644 --- a/tensorflow/cc/experimental/libexport/BUILD +++ b/tensorflow/cc/experimental/libexport/BUILD @@ -1,8 +1,8 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/cc/experimental/libtf/BUILD b/tensorflow/cc/experimental/libtf/BUILD index 379f2e430aaacd..2c67800eacff63 100644 --- a/tensorflow/cc/experimental/libtf/BUILD +++ b/tensorflow/cc/experimental/libtf/BUILD @@ -1,16 +1,16 @@ #include "third_party/absl/strings/str_cat.h" #TODO(aselle) : describe this package. -load( - "//tensorflow/core/platform:rules_cc.bzl", - "cc_library", -) +load("//tensorflow:strict.default.bzl", "py_strict_binary") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) load("//tensorflow:tensorflow.default.bzl", "filegroup") -load("//tensorflow:strict.default.bzl", "py_strict_binary") +load( + "//tensorflow/core/platform:rules_cc.bzl", + "cc_library", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/cc/experimental/libtf/impl/BUILD b/tensorflow/cc/experimental/libtf/impl/BUILD index 0eae5a1f05c133..4f5b7ccfd84940 100644 --- a/tensorflow/cc/experimental/libtf/impl/BUILD +++ b/tensorflow/cc/experimental/libtf/impl/BUILD @@ -1,13 +1,13 @@ # libtf implementation details. -load( - "//tensorflow/core/platform:rules_cc.bzl", - "cc_library", -) load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) +load( + "//tensorflow/core/platform:rules_cc.bzl", + "cc_library", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/cc/framework/fuzzing/BUILD b/tensorflow/cc/framework/fuzzing/BUILD index ec424fc0425630..74a946c283777d 100644 --- a/tensorflow/cc/framework/fuzzing/BUILD +++ b/tensorflow/cc/framework/fuzzing/BUILD @@ -1,11 +1,11 @@ # TODO(unda): describe this package. load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("//tensorflow:tensorflow.bzl", "tf_copts") load( "//tensorflow/cc/framework/fuzzing:op_fuzzing.bzl", "tf_gen_op_wrappers_fuzz", ) -load("//tensorflow:tensorflow.bzl", "tf_copts") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index a7a0af29268459..6cc731e722d16b 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -501,6 +501,7 @@ cc_library( "//tensorflow/core/graph/regularization:util", "//tensorflow/core/util/tensor_bundle:naming", "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index a98980d3c2760a..cf2ae4721623fa 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/container/btree_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -68,6 +69,7 @@ uint64_t HashCheckpointIndexFile(absl::string_view model_dir) { if (read_status.ok()) { return tensorflow::Fingerprint64(data); } else { + LOG(WARNING) << "Failed to read checkpoint file: " << read_status; return 0; } } @@ -209,8 +211,7 @@ absl::StatusOr ReadSavedModelFingerprint( absl::string_view export_dir) { const std::string fingerprint_pb_path = io::JoinPath(export_dir, kFingerprintFilenamePb); - absl::Status found_pb = Env::Default()->FileExists(fingerprint_pb_path); - if (!found_pb.ok()) return found_pb; + TF_RETURN_IF_ERROR(Env::Default()->FileExists(fingerprint_pb_path)); FingerprintDef fingerprint_proto; absl::Status result = diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index ae63fdab2fa32c..18fd6655fd269d 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -18,9 +18,11 @@ limitations under the License. #include #include #include +#include #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/fingerprinting.h" #include "tensorflow/cc/saved_model/loader_util.h" @@ -280,6 +282,16 @@ Status LoadMetagraphIntoSession(const SessionOptions& session_options, return (*session)->Create(meta_graph.graph_def()); } +Status LoadGraphDefIntoSession(const SessionOptions& session_options, + GraphDef graph_def, + std::unique_ptr* session) { + Session* session_p = nullptr; + TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); + session->reset(session_p); + TF_RETURN_IF_ERROR(ValidateSavedTensors(graph_def)); + return (*session)->Create(std::move(graph_def)); +} + Status LoadSavedModelInternal(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, @@ -296,40 +308,6 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, return absl::OkStatus(); } -Status LoadSavedModel(const SessionOptions& session_options, - const RunOptions& run_options, const string& export_dir, - const std::unordered_set& tags, - SavedModelBundle* const bundle) { - metrics::SavedModelReadApi(kCCLoadLabel).IncrementBy(1); - auto fingerprint_proto = - saved_model::fingerprinting::ReadSavedModelFingerprint(export_dir); - if (fingerprint_proto.ok()) { - // Set gauge cell with saved_model_checksum. - metrics::SavedModelReadFingerprint().Set( - std::to_string(fingerprint_proto->saved_model_checksum())); - } - - // TODO(robson): Add tests for the counters. - const uint64 start_microseconds = Env::Default()->NowMicros(); - const Status status = LoadSavedModelInternal(session_options, run_options, - export_dir, tags, bundle); - auto log_and_count = [&](const string& status_str) { - LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ") - << " }; Status: " << status_str << ": " << status << ". Took " - << GetLatencyMicroseconds(start_microseconds) << " microseconds."; - load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); - }; - if (status.ok()) { - log_and_count(kLoadAttemptSuccess); - metrics::SavedModelReadPath().Set(export_dir); - } else { - log_and_count(kLoadAttemptFail); - } - load_latency->GetCell(export_dir) - ->IncrementBy(GetLatencyMicroseconds(start_microseconds)); - return status; -} - namespace { // Session wrapper that prevents calls to Session::Create(), Session::Extend(), // and the deprecated partial-run methods. @@ -441,6 +419,70 @@ class LiteSessionWrapper : public Session { }; } // namespace +Status LoadSavedModelInternal(const SessionOptions& session_options, + const RunOptions& run_options, + const string& export_dir, + const std::unordered_set& tags, + SavedModelBundleLite* const bundle) { + MetaGraphDef meta_graph_def; + TF_RETURN_IF_ERROR( + ReadMetaGraphDefFromSavedModel(export_dir, tags, &meta_graph_def)); + std::unique_ptr session; + TF_RETURN_IF_ERROR(LoadGraphDefIntoSession( + session_options, std::move(*meta_graph_def.mutable_graph_def()), + &session)); + TF_RETURN_IF_ERROR( + RestoreSession(run_options, meta_graph_def, export_dir, &session)); + *bundle = SavedModelBundleLite( + std::make_unique(std::move(session)), + std::move(*meta_graph_def.mutable_signature_def())); + return absl::OkStatus(); +} + +template +Status LoadSavedModelGeneric(const SessionOptions& session_options, + const RunOptions& run_options, + const string& export_dir, + const std::unordered_set& tags, + BundleType* const bundle) { + metrics::SavedModelReadApi(kCCLoadLabel).IncrementBy(1); + auto fingerprint_proto = + saved_model::fingerprinting::ReadSavedModelFingerprint(export_dir); + if (fingerprint_proto.ok()) { + // Set gauge cell with saved_model_checksum. + metrics::SavedModelReadFingerprint().Set( + std::to_string(fingerprint_proto->saved_model_checksum())); + } + + // TODO(robson): Add tests for the counters. + const uint64 start_microseconds = Env::Default()->NowMicros(); + const Status status = LoadSavedModelInternal(session_options, run_options, + export_dir, tags, bundle); + auto log_and_count = [&](const string& status_str) { + LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ") + << " }; Status: " << status_str << ": " << status << ". Took " + << GetLatencyMicroseconds(start_microseconds) << " microseconds."; + load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); + }; + if (status.ok()) { + log_and_count(kLoadAttemptSuccess); + metrics::SavedModelReadPath().Set(export_dir); + } else { + log_and_count(kLoadAttemptFail); + } + load_latency->GetCell(export_dir) + ->IncrementBy(GetLatencyMicroseconds(start_microseconds)); + return status; +} + +Status LoadSavedModel(const SessionOptions& session_options, + const RunOptions& run_options, const string& export_dir, + const std::unordered_set& tags, + SavedModelBundle* const bundle) { + return LoadSavedModelGeneric(session_options, run_options, + export_dir, tags, bundle); +} + Status RestoreSession(const RunOptions& run_options, const MetaGraphDef& meta_graph, const string& export_dir, std::unique_ptr* session) { @@ -476,7 +518,6 @@ Status LoadSavedModel(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set& tags, SavedModelBundleLite* const bundle) { - SavedModelBundle legacy_bundle; SessionOptions rewritten_options(session_options); // We disallow calls to Session::Extend() on the returned session, so we can // reduce memory consumption by not storing the original GraphDef. @@ -489,11 +530,8 @@ Status LoadSavedModel(const SessionOptions& session_options, ->set_disable_output_partition_graphs(true); // TODO(mrry): Consider specializing the session creation to reduce peak // RAM consumption by using `Session::Create(GraphDef&&)`. - TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir, - tags, &legacy_bundle)); - *bundle = SavedModelBundleLite( - std::make_unique(std::move(legacy_bundle.session)), - std::move(*legacy_bundle.meta_graph_def.mutable_signature_def())); + TF_RETURN_IF_ERROR(LoadSavedModelGeneric(rewritten_options, run_options, + export_dir, tags, bundle)); return absl::OkStatus(); } diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index bb5daa99742944..10601308ac7d0f 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -2,11 +2,11 @@ #Description: # TensorFlow cc tools. -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 92d62b34be8bf9..dfedc5a4f8c6c0 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -1,8 +1,8 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "filegroup", "genrule") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 8ebeae499bd177..76f3c147903748 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,6 +1,6 @@ +load("@local_xla//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_cuda_only_cc_test") load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "tf_custom_op_py_strict_library") -load("@local_xla//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library") load( "//tensorflow/core/platform:build_config_root.bzl", @@ -520,10 +520,7 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_xla//xla:status_macros", "@local_xla//xla/pjrt:pjrt_client", ], ) @@ -585,6 +582,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:dma_helper", + "//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector", "//tensorflow/core/tfrt/common:async_value_tensor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", @@ -592,6 +590,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@local_tsl//tsl/framework:device_id_utils", + "@local_tsl//tsl/framework:serving_device_selector_policies", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", "@local_xla//xla/client:local_client", diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 66d9960ae0a62f..6372b2e5516cd3 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -1,6 +1,6 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/jit/tests/BUILD b/tensorflow/compiler/jit/tests/BUILD index e9880013bf2611..0c3c4986e44a88 100644 --- a/tensorflow/compiler/jit/tests/BUILD +++ b/tensorflow/compiler/jit/tests/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") licenses(["notice"]) diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 3f0a4847c54540..f9657509623cc1 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/op.h" @@ -58,6 +59,7 @@ limitations under the License. #include "tensorflow/core/tfrt/common/async_value_tensor.h" #include "tensorflow/core/util/stream_executor_util.h" #include "tsl/framework/device_id_utils.h" +#include "tsl/framework/serving_device_selector_policies.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -863,12 +865,35 @@ Status RunPjRtExecutable( TF_ASSIGN_OR_RETURN(xla::PjRtDevice * device, pjrt_client->LookupAddressableDevice(pjrt_device_id)); + gpu::GpuServingDeviceSelectorResource* device_selector_resource = nullptr; + if (device_type == DEVICE_GPU) { + auto rm = ctx->resource_manager(); + TF_RETURN_IF_ERROR(rm->LookupOrCreate< + gpu::GpuServingDeviceSelectorResource>( + rm->default_container(), gpu::kGpuServingDeviceSelectorResourceName, + &device_selector_resource, + [&](gpu::GpuServingDeviceSelectorResource** device_selector_resource) { + *device_selector_resource = new gpu::GpuServingDeviceSelectorResource( + pjrt_client->addressable_device_count(), + std::make_unique()); + return absl::OkStatus(); + })); + core::ScopedUnref device_selector_resource_ref(device_selector_resource); + + TF_ASSIGN_OR_RETURN(absl::string_view fingerprint, + executable->FingerprintExecutable()); + device_selector_resource->selector()->Enqueue(pjrt_device_id, fingerprint); + } TF_ASSIGN_OR_RETURN( std::vector> execute_outputs, RunPjRtExecutable(num_missing_prefix_ctx_inputs, inputs, variable_snapshots, updated_variables, device_type, use_pjrt_tensor_buffer, compilation_result, device, pjrt_client, executable)); + if (device_selector_resource != nullptr) { + device_selector_resource->selector()->Completed(pjrt_device_id, + /*had_error=*/false); + } TF_RETURN_IF_ERROR(PopulateCtxOutputsFromPjRtExecutableOutputs( num_missing_prefix_ctx_inputs, inputs, updated_variables, diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index b30f08a1bfe1b4..d0286e5acff9ce 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -59,6 +59,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", # buildcleaner:keep "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:runtime_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes", "//tensorflow/compiler/mlir/tf2xla/internal/passes:mlir_to_graph_passes", @@ -69,6 +70,7 @@ cc_library( "//tensorflow/compiler/mlir/tosa:tfl_passes", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir/framework/ir:xla_framework", "@local_xla//xla/mlir/framework/transforms:passes", diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 95b9e92fa4c97f..c3826f1bfb935c 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -349,7 +349,6 @@ cc_library( "transforms/passes.h", "utils/attribute_utils.h", "utils/utils.h", - "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], deps = [ ":converter_inc", @@ -382,8 +381,10 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", ], @@ -703,6 +704,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", @@ -738,11 +740,14 @@ cc_library( "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:status", "@local_xla//xla:statusor", "@local_xla//xla/mlir_hlo", + "@stablehlo//:stablehlo_ops", ], ) @@ -854,12 +859,16 @@ cc_library( ], deps = [ "convert_type", + ":op_quant_spec_getters_inc", ":tensorflow_lite", ":tensorflow_lite_passes_inc_gen", + ":tensorflow_lite_post_quantize_inc_gen", + ":tensorflow_lite_quantize_inc_gen", ":validators", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std", "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", @@ -878,6 +887,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_xla//xla/mlir_hlo", + "@stablehlo//:stablehlo_ops", ], ) @@ -1146,6 +1156,7 @@ cc_library( ":size_utils", ":tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/tensorflow", @@ -1373,6 +1384,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", "//tensorflow/compiler/mlir/lite/stablehlo:compose_uniform_quantized_type_pass", + "//tensorflow/compiler/mlir/lite/stablehlo:composite_lowering", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", # buildcleaner: keep @@ -1411,9 +1423,9 @@ cc_library( "//tensorflow/compiler/mlir/lite/metrics:error_collector", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization/stablehlo:quantization", + "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:op_stat_pass", - "//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_tfl", "//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_util", "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index dd4f59ebe3a889..69ec0bbbcee3dc 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -99,6 +99,9 @@ struct PassConfig { // When set to true, StableHLO Quantizer is run. The full configuration for // the quantizer is at `TocoFlags::quantization_config`. bool enable_stablehlo_quantizer = false; + + // Enables the attempt to directly lower composites into tflite ops. + bool enable_composite_direct_lowering = false; }; inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index 21bf8f739aea78..248a55c7fe17e1 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:strict.default.bzl", "py_strict_library") load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load( "@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", ) +load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -82,9 +82,12 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SideEffectInterfaces", ], ) @@ -197,6 +200,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -215,6 +219,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/experimental/tac/hardwares:target_hardware", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -259,10 +264,13 @@ cc_library( "@com_google_protobuf//:protobuf_headers", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/examples/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/examples/BUILD index 57fb5ea9eef10d..c5707a5f888885 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/examples/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/examples/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD index 7ee0b43b84d98c..5b3f2836feec99 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.default.bzl", "pybind_extension") load("//tensorflow:tensorflow.bzl", "VERSION") +load("//tensorflow:tensorflow.default.bzl", "pybind_extension") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 1a9ff8016649ef..b98d3220ee15a8 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -181,6 +181,8 @@ static StatusOr GetTFLiteType(Type type, return tflite::TensorType_FLOAT32; } else if (type.isF16()) { return tflite::TensorType_FLOAT16; + } else if (type.isBF16()) { + return tflite::TensorType_BFLOAT16; } else if (type.isF64()) { return tflite::TensorType_FLOAT64; } else if (type.isa()) { @@ -577,9 +579,6 @@ class Translator { module.getContext()->getOrLoadDialect(); tfl_dialect_ = module.getContext() ->getOrLoadDialect(); - stablehlo_dialect_ = - module.getContext() - ->getOrLoadDialect(); vhlo_dialect_ = module.getContext()->getOrLoadDialect(); // Right now the TF executor dialect is still needed to build NodeDef. @@ -834,7 +833,6 @@ class Translator { // dialect is not registered. const Dialect* tf_dialect_; const Dialect* tfl_dialect_; - const Dialect* stablehlo_dialect_; const Dialect* vhlo_dialect_; // The failed ops during legalization. @@ -1996,35 +1994,6 @@ std::optional> Translator::BuildOperator( return offset; } - // EXPERIMENTAL: If the source is in stablehlo dialect, also create them as - // builtin ops - if (dialect == stablehlo_dialect_) { - // for stablehlo ops with kernels, we directly serialize them whenever - // possible - if (auto shlo_op = llvm::dyn_cast(inst)) { - return BuildStablehloScatterOp(shlo_op, operands, results); - } - if (auto shlo_op = - llvm::dyn_cast(inst)) { - return BuildStablehloRngBitGeneratorOp(shlo_op, operands, results); - } - if (auto shlo_op = llvm::dyn_cast(inst)) { - return BuildStablehloGatherOp(shlo_op, operands, results); - } - if (auto shlo_op = llvm::dyn_cast(inst)) { - return BuildStablehloReduceWindowOp(shlo_op, operands, results); - } - if (auto shlo_op = llvm::dyn_cast(inst)) { - return BuildStablehloPadOp(shlo_op, operands, results); - } - if (auto shlo_op = llvm::dyn_cast(inst)) { - return BuildStablehloOperatorwithoutOptions( - shlo_op, operands, results, tflite::BuiltinOperator_STABLEHLO_ADD); - } - return inst->emitOpError("is not part of the stablehlo support yet."), - std::nullopt; - } - if (dialect == vhlo_dialect_) { mlir::VhloToStablehloTypeConverter vhlo_type_converter; if (auto vhlo_op = llvm::dyn_cast(inst)) { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 55388c86dfc7bf..481f5573058b8c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -3926,10 +3926,10 @@ def TFL_CastOp : TFL_Op<"cast", [ }]; let arguments = (ins - TFL_TensorOf<[F16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$input + TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$input ); - let results = (outs TFL_TensorOf<[F16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$output); + let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$output); // TFLite's cast op does not utilize CastOptions, instead derives types // from the TfLiteTensors. diff --git a/tensorflow/compiler/mlir/lite/metrics/BUILD b/tensorflow/compiler/mlir/lite/metrics/BUILD index dfdb63ce59ef5c..6218a2fb30a829 100644 --- a/tensorflow/compiler/mlir/lite/metrics/BUILD +++ b/tensorflow/compiler/mlir/lite/metrics/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -33,6 +33,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) @@ -43,11 +44,15 @@ tf_cc_test( "testdata/strided_slice.mlir", ], deps = [ + ":error_collector", ":error_collector_inst", ":types_util", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:test", + "//tensorflow/core/platform:errors", "//tensorflow/core/platform:resource_loader", + "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", + "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.cc b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.cc index 9a6c173f8c4f9d..6e31d8cb21f29a 100644 --- a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.cc +++ b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.cc @@ -21,7 +21,11 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_split.h" #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" +#include "tensorflow/compiler/mlir/lite/metrics/types_util.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h index 322ec2e852d8cc..b5d66c622ab389 100644 --- a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h +++ b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h @@ -20,7 +20,9 @@ limitations under the License. #include #include +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/PassInstrumentation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" diff --git a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc index ee433b0ded933c..f7d20783b6ea81 100644 --- a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc +++ b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc @@ -23,19 +23,29 @@ limitations under the License. #include #include -#include "llvm/Support/MemoryBuffer.h" +#include "absl/status/statusor.h" +#include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" #include "tensorflow/compiler/mlir/lite/metrics/types_util.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/lite/python/metrics/converter_error_data.pb.h" #include "tsl/platform/statusor.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/metrics/types_util.cc b/tensorflow/compiler/mlir/lite/metrics/types_util.cc index 96a167b3254ba6..b47347ceb03827 100644 --- a/tensorflow/compiler/mlir/lite/metrics/types_util.cc +++ b/tensorflow/compiler/mlir/lite/metrics/types_util.cc @@ -16,8 +16,11 @@ limitations under the License. #include +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "tensorflow/lite/python/metrics/converter_error_data.pb.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 8cf08dea534d5c..3e50192fa0640d 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -204,6 +204,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer( pass_config.legalize_custom_tensor_list_ops = toco_flags.legalize_custom_tensor_list_ops(); pass_config.enable_stablehlo_quantizer = toco_flags.has_quantization_config(); + pass_config.enable_composite_direct_lowering = + toco_flags.enable_composite_direct_lowering(); if (toco_flags.qdq_conversion_mode() == "STATIC") { pass_config.quant_specs.qdq_conversion_mode = diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD index 727fb03d833964..a6d6c61444548e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD @@ -1,6 +1,6 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -77,24 +77,22 @@ cc_library( srcs = [ "ConvertConst.cc", "ConvertSimQuant.cc", - "FakeQuantSupport.cc", "QuantOps.cc", "QuantizeUtils.cc", - "UniformSupport.cc", ], hdrs = [ - "FakeQuantSupport.h", "Passes.h", "QuantOps.h", "QuantizeUtils.h", - "UniformSupport.h", ], compatible_with = get_compatible_with_portable(), deps = [ ":QuantOpsIncGen", ":QuantPassIncGen", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc index ae9b67e9e60af6..3de159a1414429 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/ir/Passes.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h" using namespace mlir; using namespace mlir::quantfork; diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc index f64c400d4fb155..e99addc5b5f8a5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc @@ -16,10 +16,10 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/Passes.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" using namespace mlir; using namespace mlir::quantfork; diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc index 67c1c7d9284f2b..919c711272b2c1 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc @@ -19,7 +19,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" using namespace mlir; using namespace mlir::quantfork; diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index ad7c1905440297..66df4f528aa43d 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -88,9 +88,11 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD index df286611f3e356..f96d4961e733b4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD @@ -18,12 +18,14 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc index ccba41d07e103b..08f5ecd4851b7e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" @@ -41,6 +43,7 @@ namespace tensorflow { namespace { using ::mlir::quant::stablehlo::StaticRangePtqComponent; +using ::stablehlo::quantization::PopulateDefaults; using ::stablehlo::quantization::QuantizationConfig; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::PyFunctionLibrary; @@ -79,7 +82,7 @@ absl::StatusOr RunQuantization( const SavedModelBundle* saved_model_bundle, const absl::string_view saved_model_dir, const std::unordered_set& saved_model_tags, - QuantizationConfig& quantization_config, + const QuantizationConfig& quantization_config, const PyFunctionLibrary* quantization_py_function_lib, mlir::ModuleOp module_op) { if (saved_model_bundle == nullptr) { @@ -94,10 +97,11 @@ absl::StatusOr RunQuantization( "be nullptr."); } - if (!quantization_config.has_calibration_options()) { - *quantization_config.mutable_calibration_options() = - mlir::quant::stablehlo::GetDefaultCalibrationOptions(); - } + LOG(INFO) << "User-provided quantization config: " + << quantization_config.DebugString(); + const QuantizationConfig updated_config = + ExpandPresets(PopulateDefaults(quantization_config)); + LOG(INFO) << "Updated quantization config: " << updated_config.DebugString(); const absl::flat_hash_map signature_def_map = GetSignatureDefMapFromBundle(*saved_model_bundle); @@ -131,8 +135,9 @@ absl::StatusOr RunQuantization( module_op.getContext(), quantization_py_function_lib, saved_model_dir, /*signature_keys=*/exported_names, saved_model_tags, signature_def_map, GetFunctionAliases(*saved_model_bundle)); - const absl::StatusOr quantized_module_op = - static_range_ptq_component.Run(module_op, quantization_config); + + absl::StatusOr quantized_module_op = + static_range_ptq_component.Run(module_op, updated_config); if (!quantized_module_op.ok()) { return absl::InternalError("Failed to run quantization. Status msg: " + quantized_module_op.status().ToString()); diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h index ef6496315e8e61..c55d59cad0f1a0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h @@ -50,7 +50,7 @@ absl::StatusOr RunQuantization( const SavedModelBundle* saved_model_bundle, absl::string_view saved_model_dir, const std::unordered_set& saved_model_tags, - stablehlo::quantization::QuantizationConfig& quantization_config, + const stablehlo::quantization::QuantizationConfig& quantization_config, const tensorflow::quantization::PyFunctionLibrary* quantization_py_function_lib, mlir::ModuleOp module_op); diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD index 4f2e681a986f65..fce754995766d5 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/BUILD +++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 92b6d2c9abb7b3..bd83f16de105f8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -33,30 +33,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "stablehlo_tfl", - srcs = [ - "transforms/stablehlo_tfl_pass.cc", - ], - hdrs = [ - "transforms/stablehlo_tfl_pass.h", - ], - copts = [ - "-Ithird_party", - ], - deps = [ - "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "@flatbuffers", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - "@stablehlo//:stablehlo_ops", - ], - alwayslink = 1, -) - cc_library( name = "stablehlo_util", srcs = [ @@ -110,6 +86,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -133,6 +110,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", ], @@ -162,10 +140,12 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@local_xla//xla/mlir_hlo:mhlo_passes", + "@local_xla//xla/mlir_hlo:type_conversion", "@stablehlo//:chlo_ops", "@stablehlo//:register", ], @@ -213,6 +193,7 @@ cc_library( ":drop_savedmodel_semantics", ":fold_broadcast_pass", ":fuse_convolution_pass", + ":legalize_stablehlo_custom_call_to_composite", ":legalize_tf_xla_call_module_to_stablehlo_pass", ":optimize", ":rename_entrypoint_to_main", @@ -337,6 +318,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -361,6 +343,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -389,6 +372,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -441,6 +425,7 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", @@ -449,6 +434,32 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "legalize_stablehlo_composite_to_tfl_custom", + srcs = [ + "transforms/legalize_stablehlo_composite_to_tfl_custom.cc", + ], + hdrs = [ + "transforms/passes.h", + "transforms/passes.h.inc", + ], + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@flatbuffers", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_ops", + ], + alwayslink = 1, +) + cc_library( name = "legalize_stablehlo_to_vhlo_pass", srcs = [ @@ -492,6 +503,31 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "legalize_stablehlo_custom_call_to_composite", + srcs = [ + "transforms/legalize_stablehlo_custom_call_to_composite.cc", + ], + hdrs = [ + "transforms/passes.h", + "transforms/passes.h.inc", + ], + copts = [ + "-Ithird_party", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@stablehlo//:stablehlo_ops", + ], + alwayslink = 1, +) + cc_library( name = "optimize", srcs = [ @@ -509,6 +545,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -648,6 +685,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", "@stablehlo//:broadcast_utils", @@ -655,6 +693,50 @@ cc_library( ], ) +cc_library( + name = "composite_lowering", + srcs = [ + "transforms/composite_lowering_pass.cc", + ], + hdrs = [ + "transforms/passes.h", + ], + copts = [ + "-Ithird_party", + ], + deps = [ + ":composite_lowering_inc_gen", + ":passes_inc_gen", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], + alwayslink = True, +) + +gentbl_cc_library( + name = "composite_lowering_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "transforms/generated_composite_lowering.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/composite_lowering_patterns.td", + deps = [ + "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncTdFiles", + "@local_xla//xla/mlir_hlo:hlo_ops_td_files", + ], +) + tf_cc_binary( name = "odml_to_stablehlo", srcs = [ @@ -667,7 +749,6 @@ tf_cc_binary( deps = [ ":check_accepted_ops_pass", ":op_stat_pass", - ":stablehlo_tfl", ":stablehlo_util", ":transforms", "//tensorflow/cc/saved_model:loader", @@ -675,7 +756,6 @@ tf_cc_binary( "//tensorflow/compiler/mlir:passes", "//tensorflow/compiler/mlir/lite:flatbuffer_export", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/lite/stablehlo/serializer:flatbuffer_export", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", @@ -709,11 +789,12 @@ tf_cc_binary( ":compose_uniform_quantized_type_pass", ":fold_broadcast_pass", ":fuse_convolution_pass", + ":legalize_stablehlo_composite_to_tfl_custom", + ":legalize_stablehlo_custom_call_to_composite", ":legalize_stablehlo_to_vhlo_pass", ":legalize_tf_xla_call_module_to_stablehlo_pass", ":optimize", ":passes_inc_gen", - ":stablehlo_tfl", ":tf_legalize_hlo", ":tf_stablehlo", ":tfl_legalize_hlo", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index dfcb9de5cc717a..f1d6b237ac2ef6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -50,10 +50,8 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" @@ -191,17 +189,6 @@ tensorflow::StatusOr> ImportSavedModelOrMLIR( saved_model_bundle); } -tensorflow::Status ConvertStableHLOToFlatbuffer(mlir::ModuleOp module, - std::string* flatbuffer_str) { - mlir::odml::FlatbufferExportOptions options; - if (!mlir::odml::MlirToFlatBufferTranslateFunction(module, options, - flatbuffer_str)) { - return tensorflow::errors::Aborted("Unable to export flatbuffer"); - } - - return ::tensorflow::OkStatus(); -} - tensorflow::Status ExportModule(mlir::ModuleOp module, const std::string& output_filename, bool elide_large_elements_attrs) { @@ -212,20 +199,6 @@ tensorflow::Status ExportModule(mlir::ModuleOp module, return tensorflow::errors::Aborted("Unable to write to output path."); } - // Export TFLite Flatbuffer as output - if (export_type == "tflite") { - std::string flatbuffer_str; - auto status = - mlir::odml::ConvertStableHLOToFlatbuffer(module, &flatbuffer_str); - if (!status.ok()) { - return status; - } - - output->os() << flatbuffer_str; - output->keep(); - return ::tensorflow::OkStatus(); - } - // Export StableHLO MLIR as output std::string result; llvm::raw_string_ostream os(result); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD deleted file mode 100644 index a93ec34c1bfa81..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD +++ /dev/null @@ -1,64 +0,0 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//visibility:public", - ], - licenses = ["notice"], -) - -cc_library( - name = "flatbuffer_translator", - srcs = [ - "flatbuffer_translator.cc", - ], - hdrs = [ - "flatbuffer_operator.h", - "flatbuffer_translator.h", - ], - compatible_with = get_compatible_with_portable(), - deps = [ - "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:convert_tensor", - "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:logging", - "//tensorflow/lite/stablehlo/schema:schema_fbs", - "//tensorflow/lite/toco:toco_flags_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@flatbuffers", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", - "@local_xla//xla:statusor", - "@stablehlo//:stablehlo_ops", - ], -) - -cc_library( - name = "flatbuffer_export", - srcs = [ - "flatbuffer_export.cc", - ], - hdrs = ["flatbuffer_export.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - ":flatbuffer_translator", - "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:logging", - "//tensorflow/lite/toco:toco_flags_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@llvm-project//mlir:IR", - ], -) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.cc deleted file mode 100644 index a35f7821e68bbd..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.cc +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.h" - -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.h" -#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/platform/logging.h" - -namespace mlir { -namespace odml { - -bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, - const FlatbufferExportOptions& options, - std::string* serialized_flatbuffer) { - auto maybe_translated = Translator::Translate( - module, options.toco_flags, options.saved_model_tags, - options.op_or_arg_name_mapper, options.metadata); - if (!maybe_translated) return false; - *serialized_flatbuffer = std::move(*maybe_translated); - return true; -} - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.h b/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.h deleted file mode 100644 index ae980f6f6522ad..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_SERIALIZER_FLATBUFFER_EXPORT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_SERIALIZER_FLATBUFFER_EXPORT_H_ - -#include -#include -#include - -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" -#include "tensorflow/lite/toco/toco_flags.pb.h" - -namespace mlir { -namespace odml { - -// Options for exporting to Flatbuffer. -struct FlatbufferExportOptions { - // TocoFlags proto. The following fields are migrated. - // bool emit_builtin_tflite_ops -> !toco_flags.force_select_tf_ops() - // bool emit_select_tf_ops -> toco_flags.enable_select_tf_ops() - // bool emit_custom_ops -> toco_flags.allow_custom_ops() - // bool allow_all_select_tf_ops -> toco_flags.allow_all_select_tf_ops() - // std::set<> select_user_tf_ops -> toco_flags.select_user_tf_ops() - toco::TocoFlags toco_flags; - // When exporting from SavedModel, this will have the requested tags. - std::unordered_set saved_model_tags; - // Metadata key/value pairs to write to the flatbuffer. - std::map metadata; - // OpOrArgNameMapper to convert location of the op to name in flatbuffer. - // If not set, a default mapper will be used. - tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper = nullptr; -}; - -// Translates the given MLIR `module` into a FlatBuffer and stores the -// serialized flatbuffer into the string. -// Returns true on successful exporting, false otherwise. -bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, - const FlatbufferExportOptions& options, - std::string* serialized_flatbuffer); - -} // namespace odml -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_SERIALIZER_FLATBUFFER_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_operator.h deleted file mode 100644 index 453f7f508c39d6..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_operator.h +++ /dev/null @@ -1,173 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// prototype for stablehlo serialization, WIP -// WARNING: converting to stablehlo file is experimental feature, and no runtime -// support is provided - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_SERIALIZER_FLATBUFFER_OPERATOR_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_SERIALIZER_FLATBUFFER_OPERATOR_H_ - -#include -#include -#include - -#include "llvm/ADT/APInt.h" -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project - -namespace mlir { -namespace odml { - -// TODO(zichuanwei@): support float16/bfloat16 & int4 - -// Function calls with a non-specialized type will result to a linker error. -template -inline std::vector GetVector(DenseElementsAttr elements); - -// TODO(zichuanwei@): for each type, we need to make sure the element type -// matches the expected type otherwise an error should be thrown, but for now -// we're just returning empty vector -template <> -inline std::vector GetVector(DenseElementsAttr elements) { - auto type = elements.getType(); - auto elemType = type.getElementType(); - if (elemType.isSignlessInteger(1)) { - auto vec = llvm::to_vector( - llvm::map_range(elements.getValues(), - [&](bool value) -> uint8_t { return value ? 1 : 0; })); - return std::vector(vec.begin(), vec.end()); - } - - return std::vector(); -} - -template <> -inline std::vector GetVector(DenseElementsAttr elements) { - auto type = elements.getType(); - auto elemType = type.getElementType(); - if (elemType.isSignlessInteger(8)) { - auto vec = llvm::to_vector(llvm::map_range( - elements.getValues(), - [&](APInt value) -> int8_t { return value.getSExtValue(); })); - return std::vector(vec.begin(), vec.end()); - } - - return std::vector(); -} - -template <> -inline std::vector GetVector(DenseElementsAttr elements) { - auto type = elements.getType(); - auto elemType = type.getElementType(); - if (elemType.isSignlessInteger(16)) { - auto vec = llvm::to_vector(llvm::map_range( - elements.getValues(), - [&](APInt value) -> int16_t { return value.getSExtValue(); })); - return std::vector(vec.begin(), vec.end()); - } - - return std::vector(); -} - -template <> -inline std::vector GetVector(DenseElementsAttr elements) { - auto type = elements.getType(); - auto elemType = type.getElementType(); - if (elemType.isSignlessInteger(32)) { - auto vec = llvm::to_vector(llvm::map_range( - elements.getValues(), - [&](APInt value) -> int32_t { return value.getSExtValue(); })); - return std::vector(vec.begin(), vec.end()); - } - - return std::vector(); -} - -template <> -inline std::vector GetVector(DenseElementsAttr elements) { - auto type = elements.getType(); - auto elemType = type.getElementType(); - if (elemType.isSignlessInteger(64)) { - auto vec = llvm::to_vector(llvm::map_range( - elements.getValues(), - [&](APInt value) -> int64_t { return value.getSExtValue(); })); - return std::vector(vec.begin(), vec.end()); - } - - return std::vector(); -} - -template <> -inline std::vector GetVector(DenseElementsAttr elements) { - auto type = elements.getType(); - auto elemType = type.getElementType(); - if (elemType.isF32()) { - auto vec = llvm::to_vector(llvm::map_range( - elements.getValues(), - [&](APFloat value) -> float { return value.convertToFloat(); })); - return std::vector(vec.begin(), vec.end()); - } - - return std::vector(); -} - -template <> -inline std::vector GetVector(DenseElementsAttr elements) { - auto type = elements.getType(); - auto elemType = type.getElementType(); - if (elemType.isF64()) { - auto vec = llvm::to_vector(llvm::map_range( - elements.getValues(), - [&](APFloat value) -> double { return value.convertToFloat(); })); - return std::vector(vec.begin(), vec.end()); - } - - return std::vector(); -} - -// Handles the case when the DenseElementsAttr doesn't exist, and when it -// doesn't returns a vector of length `default_size` all with the same value -// `default_value`. -template -static inline std::vector GetOptionalVector( - std::optional elements, int64_t default_size, - int64_t default_value) { - if (elements.has_value()) { - return GetVector(elements.value()); - } - return std::vector(default_size, default_value); -} - -// Handles the case when the SmallVector doesn't exist, and when it -// doesn't returns a vector of length `default_size` all with the same value -// `default_value`. -template -static inline std::vector GetOptionalVector( - std::optional> values, int64_t default_size, - int64_t default_value) { - if (values.has_value()) { - return std::vector(values->begin(), values->end()); - } - return std::vector(default_size, default_value); -} - -} // namespace odml -} // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_SERIALIZER_FLATBUFFER_OPERATOR_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.cc b/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.cc deleted file mode 100644 index fb5e2fadab907b..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.cc +++ /dev/null @@ -1,904 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// prototype for stablehlo serialization, WIP -// WARNING: converting to stablehlo file is experimental feature, and no runtime -// support is provided - -#include "tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_operator.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "xla/statusor.h" -#include "tensorflow/lite/stablehlo/schema/schema_generated.h" - -#define kStablehloOptionalTensor (-1) - -using llvm::isa; -using llvm::StringRef; -using llvm::Twine; -using mlir::ElementsAttr; -using mlir::ModuleOp; -using mlir::Operation; -using mlir::StringAttr; -using mlir::TensorType; -using mlir::Value; -using mlir::func::FuncOp; -using tensorflow::OpOrArgLocNameMapper; -using tensorflow::OpOrArgNameMapper; -using xla::StatusOr; - -namespace mlir { -namespace odml { - -// TODO(b/267689361) this and the following functions should be automatically -// generated similar to operator_converters.inc in tflite -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> CreateAddOperator( - mlir::stablehlo::AddOp& hlo_op, flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> -CreateReshapeOperator(mlir::stablehlo::ReshapeOp& hlo_op, - flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, - const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> CreateDivOperator( - mlir::stablehlo::DivOp& hlo_op, flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> -CreateSubtractOperator(mlir::stablehlo::SubtractOp& hlo_op, - flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, - const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> CreateMulOperator( - mlir::stablehlo::MulOp hlo_op, flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> CreateMaxOperator( - mlir::stablehlo::MaxOp& hlo_op, flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> -CreateConvertOperator(mlir::stablehlo::ConvertOp& hlo_op, - flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, - const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> CreateDotOperator( - mlir::stablehlo::DotOp& hlo_op, flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> CreateClampOperator( - mlir::stablehlo::ClampOp& hlo_op, flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> -CreateLogisticOperator(mlir::stablehlo::LogisticOp& hlo_op, - flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, - const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - return ::stablehlo::flatbuf::CreateOperator(*fbb, opcode_index, inputs, - outputs); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> -CreateConcatenateOperator(mlir::stablehlo::ConcatenateOp& hlo_op, - flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, - const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - auto options = ::stablehlo::flatbuf::CreateConcatenateOptions( - *fbb, hlo_op.getDimension()); - - return ::stablehlo::flatbuf::CreateOperator( - *fbb, opcode_index, inputs, outputs, - ::stablehlo::flatbuf::OperatorOptions_ConcatenateOptions, - options.Union()); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> -CreateConvolutionOperator(mlir::stablehlo::ConvolutionOp& hlo_op, - flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, - const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - // converting from mlir struct to std - std::vector window_strides_vec = - GetOptionalVector(hlo_op.getWindowStrides(), 0, 0); - std::vector padding_vec = - GetOptionalVector(hlo_op.getPadding(), 0, 0); - std::vector lhs_dilation_vec = - GetOptionalVector(hlo_op.getLhsDilation(), 0, 0); - std::vector rhs_dilation_vec = - GetOptionalVector(hlo_op.getRhsDilation(), 0, 0); - std::vector window_reversal_vec = - GetOptionalVector(hlo_op.getWindowReversal(), 0, 0); - const int64_t feature_group_count = hlo_op.getFeatureGroupCount(); - const int64_t batch_group_count = hlo_op.getBatchGroupCount(); - - auto conv_dimension_numbers = hlo_op.getDimensionNumbersAttr(); - - std::vector input_spatial_dimensions_vec = - conv_dimension_numbers.getInputSpatialDimensions().vec(); - std::vector kernel_spatial_dimensions_vec = - conv_dimension_numbers.getKernelSpatialDimensions().vec(); - std::vector output_spatial_dimensions_vec = - conv_dimension_numbers.getOutputSpatialDimensions().vec(); - const int64_t input_batch_dimension = - conv_dimension_numbers.getInputBatchDimension(); - const int64_t input_feature_dimension = - conv_dimension_numbers.getInputFeatureDimension(); - const int64_t kernel_input_feature_dimension = - conv_dimension_numbers.getKernelInputFeatureDimension(); - const int64_t kernel_output_feature_dimension = - conv_dimension_numbers.getKernelOutputFeatureDimension(); - const int64_t output_batch_dimension = - conv_dimension_numbers.getOutputBatchDimension(); - const int64_t output_feature_dimension = - conv_dimension_numbers.getOutputFeatureDimension(); - - // serialize all vectors to flatbuffer - auto window_strides = fbb->CreateVector(window_strides_vec); - auto padding = fbb->CreateVector(padding_vec); - auto lhs_dilation = fbb->CreateVector(lhs_dilation_vec); - auto rhs_dilation = fbb->CreateVector(rhs_dilation_vec); - auto input_spatial_dimensions = - fbb->CreateVector(input_spatial_dimensions_vec); - auto kernel_spatial_dimensions = - fbb->CreateVector(kernel_spatial_dimensions_vec); - auto output_spatial_dimensions = - fbb->CreateVector(output_spatial_dimensions_vec); - auto window_reversal = fbb->CreateVector(window_reversal_vec); - - auto options = ::stablehlo::flatbuf::CreateConvolutionOptions( - *fbb, window_strides, padding, lhs_dilation, rhs_dilation, - window_reversal, input_batch_dimension, input_feature_dimension, - input_spatial_dimensions, kernel_input_feature_dimension, - kernel_output_feature_dimension, kernel_spatial_dimensions, - output_batch_dimension, output_feature_dimension, - output_spatial_dimensions, feature_group_count, batch_group_count); - - return ::stablehlo::flatbuf::CreateOperator( - *fbb, opcode_index, inputs, outputs, - ::stablehlo::flatbuf::OperatorOptions_ConvolutionOptions, - options.Union()); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> -CreateReduceWindowOperator(mlir::stablehlo::ReduceWindowOp& hlo_op, - flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, - const std::vector& operands, - const std::vector& results, - const int subgraph_idx) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - // TODO(zichuanwei@): instead of create these vectors let's just create - // Flatbuffers vector directly - std::vector window_dimension_vec( - GetOptionalVector(hlo_op.getWindowDimensions(), 0, 0)); - std::vector window_strides_vec( - GetOptionalVector(hlo_op.getWindowStrides(), 0, 0)); - std::vector base_dilations_vec( - GetOptionalVector(hlo_op.getBaseDilations(), 0, 0)); - std::vector window_dilations_vec( - GetOptionalVector(hlo_op.getWindowDilations(), 0, 0)); - std::vector padding_vec( - GetOptionalVector(hlo_op.getPadding(), 0, 0)); - - auto window_dimension = fbb->CreateVector(window_dimension_vec); - auto window_strides = fbb->CreateVector(window_strides_vec); - auto base_dilations = fbb->CreateVector(base_dilations_vec); - auto window_dilations = fbb->CreateVector(window_dilations_vec); - auto padding = fbb->CreateVector(padding_vec); - - auto options = ::stablehlo::flatbuf::CreateReduceWindowOptions( - *fbb, window_dimension, window_strides, base_dilations, window_dilations, - padding, subgraph_idx); - - return ::stablehlo::flatbuf::CreateOperator( - *fbb, opcode_index, inputs, outputs, - ::stablehlo::flatbuf::OperatorOptions_ReduceWindowOptions, - options.Union()); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> -CreateBroadcastInDimOperator(mlir::stablehlo::BroadcastInDimOp& hlo_op, - flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, - const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - auto dims = hlo_op.getBroadcastDimensions(); - auto broadcast_dimension = - fbb->CreateVector(std::vector(dims.begin(), dims.end())); - - auto options = ::stablehlo::flatbuf::CreateBroadcastInDimOptions( - *fbb, broadcast_dimension); - - return ::stablehlo::flatbuf::CreateOperator( - *fbb, opcode_index, inputs, outputs, - ::stablehlo::flatbuf::OperatorOptions_BroadcastInDimOptions, - options.Union()); -} - -static flatbuffers::Offset<::stablehlo::flatbuf::Operator> -CreateResizeBilinearOperator(mlir::stablehlo::CustomCallOp& hlo_op, - flatbuffers::FlatBufferBuilder* fbb, - uint32_t opcode_index, - const std::vector& operands, - const std::vector& results) { - auto inputs = fbb->CreateVector(operands); - auto outputs = fbb->CreateVector(results); - - auto align_corners = - hlo_op->getAttr("align_corners").dyn_cast(); - assert(align_corners); - auto half_pixel_center = - hlo_op->getAttr("half_pixel_centers").dyn_cast(); - assert(half_pixel_center); - - auto options = ::stablehlo::flatbuf::CreateResizeBilinearOptions( - *fbb, align_corners.getValue(), half_pixel_center.getValue()); - - return ::stablehlo::flatbuf::CreateOperator( - *fbb, opcode_index, inputs, outputs, - ::stablehlo::flatbuf::OperatorOptions_ResizeBilinearOptions, - options.Union()); -} - -std::optional> -CreateFlatBufferOperator(mlir::Operation* op, uint32_t opcode_index, - const std::vector& operands, - const std::vector& results, - flatbuffers::FlatBufferBuilder* fbb, - int subgraph_idx = 0) { - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateAddOperator(hlo_op, fbb, opcode_index, operands, results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateDotOperator(hlo_op, fbb, opcode_index, operands, results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateLogisticOperator(hlo_op, fbb, opcode_index, operands, results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateDivOperator(hlo_op, fbb, opcode_index, operands, results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateSubtractOperator(hlo_op, fbb, opcode_index, operands, results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateMulOperator(hlo_op, fbb, opcode_index, operands, results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateMaxOperator(hlo_op, fbb, opcode_index, operands, results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateReshapeOperator(hlo_op, fbb, opcode_index, operands, results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateConvolutionOperator(hlo_op, fbb, opcode_index, operands, - results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateReduceWindowOperator(hlo_op, fbb, opcode_index, operands, - results, subgraph_idx); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateBroadcastInDimOperator(hlo_op, fbb, opcode_index, operands, - results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateResizeBilinearOperator(hlo_op, fbb, opcode_index, operands, - results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateClampOperator(hlo_op, fbb, opcode_index, operands, results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateConcatenateOperator(hlo_op, fbb, opcode_index, operands, - results); - if (auto hlo_op = llvm::dyn_cast(op)) - return CreateConvertOperator(hlo_op, fbb, opcode_index, operands, results); - return std::nullopt; -} - -static absl::StatusOr<::stablehlo::flatbuf::DataType> GetDataType( - Type type, bool is_signed = true) { - if (type.isF16()) return ::stablehlo::flatbuf::DataType_FLOAT16; - if (type.isF32()) return ::stablehlo::flatbuf::DataType_FLOAT32; - if (type.isF64()) return ::stablehlo::flatbuf::DataType_FLOAT64; - if (type.isSignlessInteger(8)) return ::stablehlo::flatbuf::DataType_INT8; - if (type.isSignlessInteger(16)) return ::stablehlo::flatbuf::DataType_INT16; - if (type.isSignlessInteger(32)) return ::stablehlo::flatbuf::DataType_INT32; - if (type.isSignlessInteger(64)) return ::stablehlo::flatbuf::DataType_INT64; - if (type.isUnsignedInteger(8)) return ::stablehlo::flatbuf::DataType_UINT8; - if (type.isUnsignedInteger(16)) return ::stablehlo::flatbuf::DataType_UINT16; - if (type.isUnsignedInteger(32)) return ::stablehlo::flatbuf::DataType_UINT32; - if (type.isUnsignedInteger(64)) return ::stablehlo::flatbuf::DataType_UINT64; - std::string type_str; - llvm::raw_string_ostream str_stream(type_str); - str_stream << type; - LOG(ERROR) << "unsupported datatype" << type_str; - return tensorflow::errors::InvalidArgument("unsupported datatype" + type_str); -} - -std::optional<::stablehlo::flatbuf::OperatorCode> GetOpCode( - mlir::Operation* op) { - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_ADD; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_DOT; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_SUBTRACT; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_DIVIDE; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_LOGISTIC; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_MULTIPLY; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_MAXIMUM; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_RESHAPE; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_CONVOLUTION; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_BROADCAST_IN_DIM; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_REDUCE_WINDOW; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_CLAMP; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_CONCATENATE; - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_CONVERT; - - // For now we assume the incoming custom op is a resize_bilinear, it is - // expected any other custom op will cause the program to error out - if (isa(op)) - return ::stablehlo::flatbuf::OperatorCode_RESIZE_BILINEAR; - - op->emitError(Twine("unsupported op type " + op->getName().getStringRef())); - return std::nullopt; -} - -static bool IsConst(Operation* op) { - return isa(op); -} - -std::optional Translator::Translate( - ModuleOp module, const toco::TocoFlags& toco_flags, - const std::unordered_set& tags, - OpOrArgNameMapper* op_or_arg_name_mapper, - const std::map& metadata) { - OpOrArgLocNameMapper default_op_or_arg_name_mapper; - if (!op_or_arg_name_mapper) - op_or_arg_name_mapper = &default_op_or_arg_name_mapper; - // TODO(b/267689626): sanity checkers not implemented - Translator translator(module, toco_flags, tags, op_or_arg_name_mapper, - metadata); - return translator.TranslateInternal(); -} - -std::optional Translator::TranslateInternal() { - // A list of named regions in the module with main function being the first in - // the list. The main function is required as the first subgraph in the model - // is entry point for the model. - std::vector> named_regions; - named_regions.reserve(std::distance(module_.begin(), module_.end())); - - int subgraph_idx = 0; - - // Entry functions for signature defs. - std::vector entry_functions; - std::vector non_entry_functions; - FuncOp main_fn = module_.lookupSymbol("main"); - if (main_fn != nullptr) { - // Treat the main function as a signature def when the given main function - // contains on the tf.entry_function attribute. - auto attrs = - main_fn->getAttrOfType(tf_entry_function_); - if (attrs && !attrs.empty()) { - entry_functions.push_back(main_fn); - } else { - non_entry_functions.push_back(main_fn); - } - } - - // Walk over the module collection ops with functions and while ops. - module_.walk([&](FuncOp fn) { - if (main_fn == fn) return WalkResult::advance(); - auto attrs = fn->getAttrOfType("tf.entry_function"); - if (attrs && !attrs.empty()) { - entry_functions.push_back(fn); - } else { - non_entry_functions.push_back(fn); - } - return WalkResult::advance(); - }); - - // collect all reduce window ops, this is only a temporary hack - // in the future, we should have a function to walk over all ops that have - // regions contained, the logic in stablehlo is a bit different from tfl - // dialect in that all subgraphs in tflite a enclosed in func op where - // stablehlo op maintain their own regions - std::vector reduce_window; - module_.walk([&](mlir::stablehlo::ReduceWindowOp op) { - reduce_window.push_back(op); - return WalkResult::advance(); - }); - - // Assign the subgraph index. Among the given functions, it will put entry - // functions at the beginning of the list of the subgrahs. - for (auto fn : entry_functions) { - subgraph_index_map_[fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back(fn.getName().str(), &fn.getBody()); - } - for (auto fn : non_entry_functions) { - subgraph_index_map_[fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back(fn.getName().str(), &fn.getBody()); - } - - // add regions of reduce_window ops into subgraph map. the name will be - // stablehlo.reduce_window as mlir::region is not assicoate with a name - for (auto op : reduce_window) { - reduce_window_subgraph_map_[op] = subgraph_idx++; - named_regions.emplace_back(op.getOperationName().str(), &op.getBody()); - } - - // Build subgraph for each of the named regions. - std::vector> subgraphs; - subgraphs.reserve(named_regions.size()); - int first_failed_func = -1; - - // When we export each function in the module op, intentionally, we export the - // entry functions at the beginning of the subgraph list and the - // subgraph_index is the index in entry functions and at the same, is the - // index in the subgraph list. - int subgraph_index = 0; - for (const auto& it : llvm::enumerate(named_regions)) { - auto subgraph_or = - BuildSubGraph(it.value().first, it.value().second, subgraph_index); - if (!subgraph_or) { - if (first_failed_func == -1) - // Record the index of the first region that cannot be converted. - // Keep looping through all subgraphs in the module to make sure that - // we collect the list of missing ops from the entire module. - first_failed_func = it.index(); - } else { - subgraphs.push_back(*subgraph_or); - ++subgraph_index; - } - } - // TODO(b/267801705) : Add schema version - auto model = ::stablehlo::flatbuf::CreateModel( - builder_, 0, builder_.CreateVector(opcodes_), - builder_.CreateVector(subgraphs), builder_.CreateVector(buffers_)); - ::stablehlo::flatbuf::FinishModelBuffer(builder_, model); - // There is a limit of 2GB for a flatbuffer. - if (builder_.GetSize() > 2147483648) { - LOG(ERROR) << "Model size is bigger than 2gb"; - return std::nullopt; - } - - // Return serialized string for the built FlatBuffer. - return std::string(reinterpret_cast(builder_.GetBufferPointer()), - builder_.GetSize()); -} - -std::optional> -Translator::BuildTensor(Value value, const std::string& name, - unsigned buffer_idx) { - auto type = value.getType().cast(); - - auto check_shape = - [&](llvm::ArrayRef shape_ref) -> mlir::LogicalResult { - auto is_out_of_range = [](int64_t dim) { - return dim > std::numeric_limits::max(); - }; - - if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) - return mlir::emitError( - value.getLoc(), - "result shape dimensions out of 32 bit int type range"); - - return mlir::success(); - }; - - std::vector shape; - std::vector shape_signature; - auto* inst = value.getDefiningOp(); - - bool is_variable = !(inst && IsConst(inst)); - if (type.hasStaticShape()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return std::nullopt; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } else if (inst && IsConst(inst)) { - // Const op can have a result of dynamic shaped type (e.g. due to constant - // folding), but we can still derive the shape of a constant tensor for - // its attribute type. - - auto tensor_attr = inst->getAttr("value").cast(); - llvm::ArrayRef shape_ref = - tensor_attr.getType().cast().getShape(); - if (mlir::failed(check_shape(shape_ref))) return std::nullopt; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } else if (type.hasRank()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return std::nullopt; - - shape.reserve(shape_ref.size()); - for (auto& dim : shape_ref) { - // translate dynamic shapes from mlir to tfl values - shape.push_back( - dim == mlir::ShapedType::kDynamic ? 1 : static_cast(dim)); - shape_signature.push_back(static_cast( - dim == mlir::ShapedType::kDynamic ? tensorflow::kTFDynamicSize - : dim)); - } - } - - Type element_type = type.getElementType(); - auto status = GetDataType(element_type); - if (!status.ok()) return std::nullopt; - ::stablehlo::flatbuf::DataType data_type = GetDataType(element_type).value(); - - return ::stablehlo::flatbuf::CreateTensor( - builder_, builder_.CreateVector(shape), data_type, - (is_variable ? 0 : buffer_idx), builder_.CreateString(name)); -} - -void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { - auto dict_attr = fn->getAttrOfType(tf_entry_function_); - if (!dict_attr) return; - - llvm::SmallVector input_names; - llvm::SmallVector output_names; - if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { - str.getValue().split(input_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - if (input_names.size() != fn.getNumArguments()) { - fn.emitWarning() << "invalid entry function specification"; - return; - } - for (const auto& it : llvm::enumerate(fn.getArguments())) { - name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); - } - *has_input_attr = true; - } - - if (auto str = - dict_attr.get("outputs").dyn_cast_or_null()) { - str.getValue().split(output_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - auto term = fn.back().getTerminator(); - if (output_names.size() != term->getNumOperands()) { - fn.emitWarning() << "output names (" << output_names.size() - << ") != terminator operands (" << term->getNumOperands() - << ")"; - return; - } - for (const auto& it : llvm::enumerate(term->getOperands())) { - name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); - } - } -} - -std::string Translator::UniqueName(mlir::Value val) { - return std::string(name_mapper_.GetUniqueName(val)); -} - -std::optional> -Translator::BuildSubGraph(const std::string& name, Region* region, int index) { - bool has_input_attr = false; - if (auto fn = dyn_cast(region->getParentOp())) { - InitializeNamesFromAttribute(fn, &has_input_attr); - } - std::vector> tensors; - llvm::DenseMap tensor_index_map; - - // Builds tensor and buffer for argument or operation result. Returns false - // on failure. - auto build_tensor_and_buffer = [&](Value value, const int subgraph_index, - const std::string& tensor_name) { - // NoneType represents optional and may be skipped here. - if (value.getType().isa()) { - return true; - } - - tensor_index_map.insert({value, tensors.size()}); - tensor_index_map_[subgraph_index][tensor_name] = tensors.size(); - auto tensor_or = BuildTensor(value, tensor_name, buffers_.size()); - if (!tensor_or) return false; - tensors.push_back(*tensor_or); - - if (value.getDefiningOp()) { - auto buffer_or = BuildBuffer(value); - if (!buffer_or) return false; - buffers_.push_back(*buffer_or); - } else { - // TODO(b/267802872): Tflite will create a buffer entry for every tensor - // regardless constant or not. in stablehlo serialization, we don't plan - // to keep this behaviour - buffers_.push_back(empty_buffer_); - } - return true; - }; - - std::vector> operators; - - // Maps positions of operations in bb to positions in operators - llvm::DenseMap operation_index_to_operator_index; - std::vector operators_in_mlir; - auto& bb = region->front(); - - // Main function's arguments are first passed to `input` op so they don't - // have associated tensor and buffer. Build FlatBuffer tensor and buffer for - // other functions. - for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { - mlir::BlockArgument arg = bb.getArgument(i); - std::string tensor_name; - if (has_input_attr) - tensor_name = std::string(name_mapper_.GetUniqueName(arg)); - if (tensor_name.empty()) tensor_name = absl::StrCat("arg", i); - if (!build_tensor_and_buffer(arg, index, tensor_name)) return std::nullopt; - } - - bool failed_once = false; - for (const auto& item : llvm::enumerate(bb)) { - Operation& inst = item.value(); - const int operation_index = item.index(); - if (inst.hasTrait()) break; - - for (auto val : inst.getResults()) { - std::string tensor_name = UniqueName(val); - // For "tfl.numeric_verify" op, the name is used to find out the original - // activation tensor rather than its own unique name in the visualization - // or debugging tools. - // auto builtin_code = GetOpCode(&inst); - if (!build_tensor_and_buffer(val, index, tensor_name)) - return std::nullopt; - } - - // Skip constant ops as they don't represent flatbuffer operator. - if (IsConst(&inst)) continue; - - // Fetch operand and result tensor indices. - std::vector results; - results.reserve(inst.getNumResults()); - for (auto result : inst.getResults()) { - results.push_back(tensor_index_map.lookup(result)); - } - Operation* real_inst = &inst; - std::vector operands; - operands.reserve(real_inst->getNumOperands()); - for (auto operand : real_inst->getOperands()) { - if (operand.getType().isa()) - operands.push_back(kStablehloOptionalTensor); - else - operands.push_back(tensor_index_map.lookup(operand)); - } - - if (auto flat_operator = BuildOperator(real_inst, operands, results)) { - operation_index_to_operator_index.try_emplace(operation_index, - operators.size()); - operators.push_back(*flat_operator); - operators_in_mlir.push_back(real_inst); - } else { - failed_once = true; - } - } - if (index + 1 > subgraph_op_inst_map_.size()) { - subgraph_op_inst_map_.resize(index + 1); - } - subgraph_op_inst_map_[index] = operators_in_mlir; - if (failed_once) return std::nullopt; - - // Get input and output tensor indices for the subgraph. - std::vector inputs, outputs; - for (auto arg : bb.getArguments()) { - inputs.push_back(tensor_index_map[arg]); - } - for (auto result : bb.getTerminator()->getOperands()) { - outputs.push_back(tensor_index_map[result]); - } - return ::stablehlo::flatbuf::CreateSubGraph( - builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs), - builder_.CreateVector(outputs), builder_.CreateVector(operators), - /*name=*/builder_.CreateString(name)); -} - -std::optional> -Translator::BuildBuffer(mlir::Value value) { - auto inst = value.getDefiningOp(); - ElementsAttr attr; - - if (auto cst = dyn_cast(inst)) { - // arith::ConstantOp have ElementAttr at this point due to validation of the - // TFLite module. - attr = cst.getValue().cast(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.getValue(); - } else { - return empty_buffer_; - } - - tensorflow::Tensor tensor; - auto status = tensorflow::ConvertToTensor(attr, &tensor); - if (!status.ok()) { - inst->emitError( - Twine("failed to convert value attribute to tensor with error: " + - status.ToString())); - return std::nullopt; - } - - absl::string_view tensor_data = tensor.tensor_data(); - auto buffer_data = builder_.CreateVector( - reinterpret_cast(tensor_data.data()), tensor_data.size()); - return ::stablehlo::flatbuf::CreateBuffer(builder_, buffer_data); -} - -uint32_t Translator::GetOpcodeIndex( - const std::string& op_name, ::stablehlo::flatbuf::OperatorCode op_code) { - auto it = opcode_index_map_.insert({op_name, 0}); - - // If the insert succeeded, the opcode has not been created already. Create a - // new operator code and update its index value in the map. - if (it.second) { - it.first->second = opcodes_.size(); - opcodes_.push_back(op_code); - } - return it.first->second; -} - -std::optional> -Translator::BuildOperator(Operation* inst, std::vector operands, - const std::vector& results) { - const auto* dialect = inst->getDialect(); - if (!dialect) { - inst->emitOpError("dialect is not registered"); - return std::nullopt; - } - - if (dialect == stablehlo_dialect_) { - auto op_code = GetOpCode(inst); - if (op_code == std::nullopt) { - return inst->emitOpError("op code not found"), std::nullopt; - } - - auto opcode_index = - GetOpcodeIndex(inst->getName().getStringRef().str(), op_code.value()); - std::optional> offset; - if (op_code == ::stablehlo::flatbuf::OperatorCode_REDUCE_WINDOW) { - offset = CreateFlatBufferOperator( - inst, opcode_index, operands, results, &builder_, - reduce_window_subgraph_map_ - [llvm::dyn_cast(inst)]); - } else { - offset = CreateFlatBufferOperator(inst, opcode_index, operands, results, - &builder_); - } - if (!offset) { - inst->emitOpError("is not a supported stablehlo op"); - } - return offset; - } - - return inst->emitOpError("a stableHLO op"), std::nullopt; -} - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.h b/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.h deleted file mode 100644 index d9d1b7b0a17d81..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.h +++ /dev/null @@ -1,176 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_SERIALIZER_FLATBUFFER_TRANSLATOR_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_SERIALIZER_FLATBUFFER_TRANSLATOR_H_ - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/lite/stablehlo/schema/schema_generated.h" -#include "tensorflow/lite/toco/toco_flags.pb.h" - -template -using BufferOffset = flatbuffers::Offset; - -template -using VectorBufferOffset = flatbuffers::Offset>; - -using CustomOptionsOffset = VectorBufferOffset; - -// Use initial buffer size in flatbuffer builder to be same as the initial size -// used by the TOCO export. (It does not explain rationale for this choice.) -// This number is currently inherited from Tflite -constexpr size_t kInitialBufferSize = 10240; - -namespace mlir { -namespace odml { - -// Translates an MLIR module in mhlo dialect to TFLite FlatBuffer. -class Translator { - public: - // Translates the given MLIR module into TFLite FlatBuffer format and returns - // the serialized output. Returns std::nullopt on unsupported, invalid inputs - // or internal error. - static std::optional Translate( - ModuleOp module, const toco::TocoFlags& toco_flags, - const std::unordered_set& tags, - tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper, - const std::map& metadata); - - private: - enum class OpType : char { kStablehloOp }; - explicit Translator(ModuleOp module, const toco::TocoFlags& toco_flags, - const std::unordered_set& saved_model_tags, - tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper, - const std::map& metadata) - : module_(module), - name_mapper_(*op_or_arg_name_mapper), - builder_(kInitialBufferSize), - saved_model_tags_(saved_model_tags) { - // The first buffer must be empty according to the schema definition. - empty_buffer_ = ::stablehlo::flatbuf::CreateBuffer(builder_); - buffers_.push_back(empty_buffer_); - stablehlo_dialect_ = - module.getContext() - ->getOrLoadDialect(); - // Right now the TF executor dialect is still needed to build NodeDef. - module.getContext() - ->getOrLoadDialect(); - } - - std::optional TranslateInternal(); - - // Returns TFLite buffer populated with constant value if the operation is - // TFLite constant operation. Otherwise, returns an empty buffer. Emits error - // and returns std::nullopt on failure. - std::optional> BuildBuffer( - Value value); - - // Builds TFLite tensor from the given value. `buffer_idx` is index of the - // corresponding buffer. Emits error and returns std::nullopt on failure. - std::optional> BuildTensor( - Value value, const std::string& name, unsigned buffer_idx); - - // Returns opcode index for op identified by the op_name, if already - // available. Otherwise, creates a new OperatorCode using the given `builtin` - // operator and associates it with `op_name`. - uint32_t GetOpcodeIndex(const std::string& op_name, - ::stablehlo::flatbuf::OperatorCode op_code); - - // Builds operator for the given operation with specified operand and result - // tensor indices. Emits an error and returns std::nullopt on failure. - std::optional> BuildOperator( - Operation* inst, std::vector operands, - const std::vector& results); - - // Build a subgraph with a given name out of the region either corresponding - // to a function's body or while op. Modifies *region by calling - // ExtractControlEdges. - std::optional> BuildSubGraph( - const std::string& name, Region* region, int index); - - // Uses the tf.entry_function attribute (if set) to initialize the op to name - // mapping. - void InitializeNamesFromAttribute(mlir::func::FuncOp fn, - bool* has_input_attr); - - // Returns a unique name for `val`. - std::string UniqueName(mlir::Value val); - - ModuleOp module_; - - tensorflow::OpOrArgNameMapper& name_mapper_; - - flatbuffers::FlatBufferBuilder builder_; - BufferOffset<::stablehlo::flatbuf::Buffer> empty_buffer_; - - std::vector> buffers_; - // Maps subgraph index and tensor name in the graph to the tensor index. - absl::flat_hash_map> - tensor_index_map_; - - // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. - absl::flat_hash_map opcode_index_map_; - std::vector opcodes_; - - // Maps function name to index of the corresponding subgraph in the FlatBuffer - // model. - absl::flat_hash_map subgraph_index_map_; - absl::flat_hash_set enabled_op_types_; - - // maps between reduce_window op and their corresponding subgraphs - std::map reduce_window_subgraph_map_; - - // Points to stablehlo dialects & mhlo dialects, respectively. nullptr if the - // dialect is not registered. - Dialect* stablehlo_dialect_; - - // Set of saved model tags, if any. - const std::unordered_set saved_model_tags_; - // Map of key value pairs of metadata to export. - const std::map metadata_; - // A mapping table to mlir::Operation objects for TFL subgraph and operator - // index in a flatbuffer. - std::vector> subgraph_op_inst_map_; - - const std::string tf_entry_function_ = "tf.entry_function"; -}; - -} // namespace odml -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_SERIALIZER_FLATBUFFER_TRANSLATOR_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD index 79cb17374fa940..dd691a25be14d9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir new file mode 100644 index 00000000000000..5924d0dce396c4 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir @@ -0,0 +1,34 @@ +// RUN: odml-to-stablehlo-opt -composite-lowering -verify-diagnostics %s | FileCheck %s + +func.func @hardswish(%arg0: tensor<2xf32>) -> (tensor<*xf32>) { + %0 = mhlo.composite "aten.hardswish.default" %arg0 {decomposition = @XlaCallModule_aten.hardswish.default.impl_0} : (tensor<2xf32>) -> tensor<2xf32> + %1 = "tf.Identity"(%0) {device = ""} : (tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} +func.func private @XlaCallModule_aten.hardswish.default.impl_0(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = mhlo.constant dense<6.000000e+00> : tensor + %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %2 = mhlo.constant dense<3.40282347E+38> : tensor + %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %4 = mhlo.constant dense<3.000000e+00> : tensor + %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %6 = mhlo.constant dense<0.000000e+00> : tensor + %7 = "mhlo.broadcast_in_dim"(%6) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %8 = mhlo.constant dense<-3.40282347E+38> : tensor + %9 = "mhlo.broadcast_in_dim"(%8) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %10 = mhlo.add %arg0, %5 : tensor<2xf32> + %11 = mhlo.clamp %7, %10, %3 : tensor<2xf32> + %12 = mhlo.clamp %9, %11, %1 : tensor<2xf32> + %13 = mhlo.multiply %arg0, %12 : tensor<2xf32> + %14 = mhlo.divide %13, %1 : tensor<2xf32> + return %14 : tensor<2xf32> +} + +// CHECK-LABEL: func.func @hardswish( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tfl.hard_swish"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_2:.*]] = "tf.Identity"(%[[VAL_1]]) {device = ""} : (tensor<2xf32>) -> tensor<*xf32> +// CHECK: %[[VAL_3:.*]] = "tf.Identity"(%[[VAL_2]]) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_3]] : tensor<*xf32> +// CHECK: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir deleted file mode 100644 index d0da1f09fa5ae1..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | flatbuffer_translate -mlir-to-tflite-flatbuffer - -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = stablehlo.add %arg0, %arg0 : tensor<2xi32> - %1 = stablehlo.subtract %0, %arg0 : tensor<2xi32> - func.return %1 : tensor<2xi32> -} -} - -// CHECK: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %1 : tensor<2xi32> -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-add.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-add.mlir deleted file mode 100644 index b0eb02192f4dad..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-add.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = stablehlo.add %arg0, %arg0 : tensor<2xi32> - func.return %0 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %0 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir deleted file mode 100644 index 85653de898aa01..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { - %0= "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> - func.return %0 : tensor<1x2x2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0) {custom_code = "stablehlo.broadcast_in_dim", custom_option = #tfl} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> -// CHECK-NEXT: return %0 : tensor<1x2x2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-clamp.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-clamp.mlir deleted file mode 100644 index 2d0051afde986b..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-clamp.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = "stablehlo.clamp"(%arg0, %arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0, %arg0) {custom_code = "stablehlo.clamp", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %0 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-compare.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-compare.mlir deleted file mode 100644 index 44b69ab933039f..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-compare.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xi1> { - %0 = stablehlo.compare LT, %arg0, %arg1 : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %1 = stablehlo.compare LT, %arg0, %arg1, TOTALORDER : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %2 = stablehlo.compare GT, %arg2, %arg3 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> - func.return %2 : tensor<2xi1> -} -} - -// CHECK: module { -// CHECK-NEXT: func.func @main(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xi1> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.compare", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK-NEXT: %1 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.compare", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK-NEXT: %2 = "tfl.custom"(%arg2, %arg3) {custom_code = "stablehlo.compare", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: return %2 : tensor<2xi1> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir new file mode 100644 index 00000000000000..41a94b929c0f47 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir @@ -0,0 +1,37 @@ +// RUN: odml-to-stablehlo-opt %s -stablehlo-composite-legalize-tfl-custom | FileCheck %s +// RUN: tf_tfl_translate --enable-hlo-to-tf-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s --check-prefix=CHECK-ROUNDTRIP + +module { + func.func public @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<1x100x32x4xf32>, + %arg3: tensor<1x500x4x4xf32>, %arg4: tensor<1x500x4x4xf32>, %arg5: tensor<1x1x100x500xf32>, %arg6: tensor) + -> (tensor<3x3xf32>, tensor<1x100x32x4xf32>) { + // CHECK-ROUNDTRIP: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "odml.update_kv_cache", custom_option = #tfl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK-ROUNDTRIP: %1 = "tfl.custom"(%arg2, %arg3, %arg4, %arg5, %arg6) {custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl} : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> + %0 = func.call @test_kv_cache(%arg0, %arg1) : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + %1 = func.call @test_sdpa(%arg2, %arg3, %arg4, %arg5, %arg6) : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> + return %0, %1 : tensor<3x3xf32>, tensor<1x100x32x4xf32> + } + + // CHECK-LABEL: func.func private @test_kv_cache + func.func private @test_kv_cache(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { + // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "odml.update_kv_cache", custom_option = #tfl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + %0 = stablehlo.composite "odml.update_kv_cache" %arg0, %arg1 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + } + func.func private @odml.update_kv_cache.impl(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { + // No decomposition provided for test case. + return %arg0 : tensor<3x3xf32> + } + + // CHECK-LABEL: func.func private @test_sdpa + func.func private @test_sdpa(%arg0: tensor<1x100x32x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<1x500x4x4xf32>, %arg3: tensor<1x1x100x500xf32>, %arg4: tensor) -> tensor<1x100x32x4xf32> { + // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2, %arg3, %arg4) {custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl} : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> + %0 = stablehlo.composite "odml.scaled_dot_product_attention" %arg0, %arg1, %arg2, %arg3, %arg4 {decomposition = @odml.scaled_dot_product_attention.impl} : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> + return %0 : tensor<1x100x32x4xf32> + } + func.func private @odml.scaled_dot_product_attention.impl(%arg0: tensor<1x100x32x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<1x500x4x4xf32>, %arg3: tensor<1x1x100x500xf32>, %arg4: tensor) -> tensor<1x100x32x4xf32> { + // No decomposition provided for test case. + return %arg0 : tensor<1x100x32x4xf32> + } + +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-concat.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-concat.mlir deleted file mode 100644 index 4be83175a417e1..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-concat.mlir +++ /dev/null @@ -1,18 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - %1 = "stablehlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> - func.return %1 : tensor<6x3xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.concatenate", custom_option = #tfl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> -// CHECK-NEXT: return %0 : tensor<6x3xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } - - - diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-constant.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-constant.mlir deleted file mode 100644 index 62c2253869c725..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-constant.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main() -> tensor<2xf32> { - %0 = stablehlo.constant dense<2> : tensor - %1 = stablehlo.constant dense<[10.0, 11.0]> : tensor<2xf32> - func.return %1 : tensor<2xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func.func @main() -> tensor<2xf32> { -// CHECK-NEXT: %0 = "tfl.custom"() {custom_code = "stablehlo.constant", custom_option = #tfl} : () -> tensor -// CHECK-NEXT: %1 = "tfl.custom"() {custom_code = "stablehlo.constant", custom_option = #tfl} : () -> tensor<2xf32> -// CHECK-NEXT: return %1 : tensor<2xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir deleted file mode 100644 index aa7742c15e4c42..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir +++ /dev/null @@ -1,27 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck -dump-input always %s - -module { -func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, feature_group_count = 1 : i64, lhs_dilation = array, padding = dense<1> : tensor<2x2xi64>, precision_config = [#stablehlo, #stablehlo], rhs_dilation = array, window_strides = array, window_reversal = array} : - (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> - func.return %0 : tensor<16x8x8x1xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.convolution", custom_option = #tfl} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> -// CHECK-NEXT: return %0 : tensor<16x8x8x1xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-dot.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-dot.mlir deleted file mode 100644 index ef715f778e8292..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-dot.mlir +++ /dev/null @@ -1,22 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<72x2048xf32>, %arg1: tensor<2048x512xf32>) -> tensor<72x512xf32> { - %0 = "stablehlo.dot"(%arg0, %arg1) { - dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [0, 1], - rhs_batching_dimensions = [1, 2], - lhs_contracting_dimensions = [0, 1], - rhs_contracting_dimensions = [1, 2] - >} : - (tensor<72x2048xf32>, tensor<2048x512xf32>) -> tensor<72x512xf32> - func.return %0 : tensor<72x512xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func.func @main(%arg0: tensor<72x2048xf32>, %arg1: tensor<2048x512xf32>) -> tensor<72x512xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.dot", custom_option = #tfl} : (tensor<72x2048xf32>, tensor<2048x512xf32>) -> tensor<72x512xf32> -// CHECK-NEXT: return %0 : tensor<72x512xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir deleted file mode 100644 index 47c716c0ca5243..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<1x128x256xf32>, %arg1: tensor<30x1x2xi32>) -> tensor<30x1x256xf32> { - %0 = "stablehlo.gather"(%arg0, %arg1) { - dimension_numbers = #stablehlo.gather< - offset_dims = [2], - collapsed_slice_dims = [0, 1], - start_index_map = [0, 1], - index_vector_dim = 2>, - indices_are_sorted = false, - slice_sizes = array} : - (tensor<1x128x256xf32>, tensor<30x1x2xi32>) -> tensor<30x1x256xf32> - func.return %0 : tensor<30x1x256xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func.func @main(%arg0: tensor<1x128x256xf32>, %arg1: tensor<30x1x2xi32>) -> tensor<30x1x256xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.gather", custom_option = #tfl} : (tensor<1x128x256xf32>, tensor<30x1x2xi32>) -> tensor<30x1x256xf32> -// CHECK-NEXT: return %0 : tensor<30x1x256xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-max.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-max.mlir deleted file mode 100644 index e8ccfcaee07805..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-max.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = stablehlo.maximum %arg0, %arg0 : tensor<2xi32> - func.return %0 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.maximum", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %0 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-mul.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-mul.mlir deleted file mode 100644 index b4bcbc455f2d24..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-mul.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = stablehlo.multiply %arg0, %arg0 : tensor<2xi32> - func.return %0 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.multiply", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %0 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir deleted file mode 100644 index bffb1da2b07117..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { - %0 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_low = array, - edge_padding_high = array, - interior_padding = array - } : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> - func.return %0 : tensor<11x131xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.pad", custom_option = #tfl} : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> -// CHECK-NEXT: return %0 : tensor<11x131xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-reshape.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-reshape.mlir deleted file mode 100644 index 281f14bf8b844e..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-reshape.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<2xi32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0) {custom_code = "stablehlo.reshape", custom_option = #tfl} : (tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %0 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-rsqrt.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-rsqrt.mlir deleted file mode 100644 index f352e19959cba1..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-rsqrt.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "stablehlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0) {custom_code = "stablehlo.rsqrt", custom_option = #tfl} : (tensor<2xf32>) -> tensor<2xf32> -// CHECK-NEXT: return %0 : tensor<2xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-scatter.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-scatter.mlir deleted file mode 100644 index 5bd79227f576b8..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-scatter.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xi32>) -> tensor<3xi32> { - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - "stablehlo.return"(%arg4) : (tensor) -> () - }) { - scatter_dimension_numbers = #stablehlo.scatter< - update_window_dims = [], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1>, - indices_are_sorted = false, - unique_indices = false} : - (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> - func.return %0 : tensor<3xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func.func @main(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xi32>) -> tensor<3xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "stablehlo.scatter", custom_option = #tfl} : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: return %0 : tensor<3xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-sub.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-sub.mlir deleted file mode 100644 index bc4f72fd2bcd48..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-sub.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = stablehlo.subtract %arg0, %arg0 : tensor<2xi32> - func.return %0 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %0 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl.mlir deleted file mode 100644 index 8898fac4288218..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = stablehlo.add %arg0, %arg0 : tensor<2xi32> - %1 = stablehlo.subtract %0, %arg0 : tensor<2xi32> - func.return %1 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %1 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir index ec8ab139054e63..4a0f6a5d5e673b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir @@ -1,10 +1,12 @@ // RUN: odml_to_stablehlo %s -skip-resize -smuggle-disallowed-ops -o - | FileCheck %s +// RUN: odml-to-stablehlo-opt %s --smuggle-disallowed-ops-pass | FileCheck %s --check-prefix=CHECK-OPT // CHECK-LABEL: @main module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 975 : i32}, tf_saved_model.semantics} { func.func @serving_default(%arg0: tensor<1x32x32x128xf32> {tf_saved_model.index_path = ["a"]}) -> (tensor<1x64x64x128xf32> {tf_saved_model.index_path = ["b"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "c:0", outputs = "d:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = "tf.Const"() {value = dense<[56, 904]> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK: %1 = stablehlo.custom_call @tf.ResizeBilinear(%arg0, %0) {align_corners = false, device = "", half_pixel_centers = true} : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> + // CHECK-OPT: %0 = stablehlo.custom_call @tf.ResizeBilinear(%arg0, %cst) {align_corners = false, device = "", half_pixel_centers = true} : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> %1 = "tf.ResizeBilinear"(%arg0, %0) { align_corners = false, device = "", half_pixel_centers = true } : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/stablehlo-custom-call-legalize-composite.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/stablehlo-custom-call-legalize-composite.mlir new file mode 100644 index 00000000000000..b2b12c4c47b579 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/stablehlo-custom-call-legalize-composite.mlir @@ -0,0 +1,18 @@ +// RUN: odml-to-stablehlo-opt %s -stablehlo-custom-call-legalize-composite | FileCheck %s + +// CHECK-LABEL: module +module { + // CHECK-LABEL: @main + func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) { + // CHECK: stablehlo.custom_call @foo + stablehlo.custom_call @foo() : () -> () + // CHECK-NOT: stablehlo.custom_call + // CHECK: stablehlo.composite "odml.foo" %arg0, %arg1 {composite_attributes = {bar = 500 : i64}, decomposition = @foo.impl} : (tensor<1xf32>, tensor<2xf32>) -> (tensor<2xf32>, tensor<1xf32>) + %1:2 = stablehlo.custom_call @stablehlo.composite(%arg0, %arg1) {called_computations = [@foo.impl], composite.backend_config = {attributes = {bar = 500 : i64}, name = "odml.foo"}} : (tensor<1xf32>, tensor<2xf32>) -> (tensor<2xf32>, tensor<1xf32>) + return + } + // CHECK-LABEL: func private @foo.impl + func.func private @foo.impl(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> (tensor<2xf32>, tensor<1xf32>) { + return %arg1, %arg0 : tensor<2xf32>, tensor<1xf32> + } +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index 9a9ea66195f7cb..7107f7dcb08a45 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -186,12 +186,10 @@ func.func @convolution_upstream_srq_strides(%arg0: tensor<1x3x3x4x!quant.uniform } // CHECK-LABEL: convolution_upstream_srq_strides // CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> -// CHECK-DAG: %[[CONST_0:.+]] = "tfl.pseudo_const"() {value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32> // CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> // CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[PAD:.+]] = "tfl.pad"(%[[ARG]], %[[CONST_0]]) : (tensor<1x3x3x4x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x5x5x4x!quant.uniform> // Tests that the stride_w is set to 2. -// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<1x5x5x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x3x2x2x!quant.uniform> // ----- @@ -766,7 +764,7 @@ func.func @conv_with_bias_same_padding_srq_depthwise(%arg0: tensor<1x4x5x3x!quan // ----- -// Tests that a quantized stablehlo.transpose is converted to tfl.transpose. +// Tests that a quantized `stablehlo.transpose` is converted to `tfl.transpose`. func.func @transpose( %arg0: tensor<2x3x4x!quant.uniform> @@ -783,19 +781,19 @@ func.func @transpose( // ----- -// Tests that a float stablehlo.transpose is not converted to tfl.transpose. +// Tests that a float `stablehlo.transpose` is not converted to `tfl.transpose`. -func.func @float_transpose(%arg0: tensor<2x3x4xf32>) -> tensor<4x3x2xf32> { +func.func @transpose_float(%arg0: tensor<2x3x4xf32>) -> tensor<4x3x2xf32> { %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x3x4xf32>) -> tensor<4x3x2xf32> return %0 : tensor<4x3x2xf32> } -// CHECK-LABEL: float_transpose +// CHECK-LABEL: transpose_float // CHECK-NOT: tfl.transpose // CHECK: stablehlo.transpose // ----- -// Tests that a quantized stablehlo.reshape is converted to tfl.reshape. +// Tests that a quantized `stablehlo.reshape` is converted to `tfl.reshape`. func.func @reshape( %arg0: tensor<2x3x4x!quant.uniform> @@ -812,19 +810,19 @@ func.func @reshape( // ----- -// Tests that a float stablehlo.reshape is not converted to tfl.reshape. +// Tests that a float `stablehlo.reshape` is not converted to `tfl.reshape`. -func.func @float_reshape(%arg0: tensor<2x3x4xf32>) -> tensor<6x4xf32> { +func.func @reshape_float(%arg0: tensor<2x3x4xf32>) -> tensor<6x4xf32> { %0 = stablehlo.reshape %arg0 : (tensor<2x3x4xf32>) -> tensor<6x4xf32> return %0 : tensor<6x4xf32> } -// CHECK-LABEL: float_reshape +// CHECK-LABEL: reshape_float // CHECK-NOT: tfl.reshape // CHECK: stablehlo.reshape // ----- -// Tests that a quantized stablehlo.select is converted to tfl.select_v2. +// Tests that a quantized `stablehlo.select` is converted to `tfl.select_v2`. func.func @select( %arg0: tensor<1x3xi1>, @@ -846,19 +844,20 @@ func.func @select( // ----- -// Tests that a float stablehlo.select is not converted to tfl.select_v2. +// Tests that a float `stablehlo.select` is not converted to `tfl.select_v2`. -func.func @float_select(%arg0: tensor<1x3xi1>, %arg1: tensor<1x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> { +func.func @select_float(%arg0: tensor<1x3xi1>, %arg1: tensor<1x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> { %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<1x3xi1>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } -// CHECK-LABEL: float_select +// CHECK-LABEL: select_float // CHECK-NOT: tfl.select_v2 // CHECK: stablehlo.select // ----- -// Tests that a quantized stablehlo.concatenate is converted to tfl.concatenation. +// Tests that a quantized `stablehlo.concatenate` is converted to +// `tfl.concatenation`. func.func @concatenate( %arg0: tensor<3x2x!quant.uniform>, @@ -878,20 +877,21 @@ func.func @concatenate( // ----- -// Tests that a float stablehlo.concatenate is not converted to tfl.concatenation. +// Tests that a float `stablehlo.concatenate` is not converted to +// `tfl.concatenation`. -func.func @float_concatenate(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<4x2xf32> { +func.func @concatenate_float(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<4x2xf32> { %0 = "stablehlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<4x2xf32> return %0 : tensor<4x2xf32> } -// CHECK-LABEL: float_concatenate +// CHECK-LABEL: concatenate_float // CHECK-NOT: tfl.concatenation // CHECK: stablehlo.concatenate // ----- -// Tests that a quantized stablehlo.pad without interior padding is converted to -// tfl.padv2. +// Tests that a quantized `stablehlo.pad` without interior padding is +// converted to `tfl.padv2`. func.func @pad_without_interior_padding( %arg0: tensor<2x3x!quant.uniform>, @@ -913,8 +913,8 @@ func.func @pad_without_interior_padding( // ----- -// Tests that a quantized stablehlo.pad with interior padding is converted to -// tfl.dilate and tfl.padv2. +// Tests that a quantized `stablehlo.pad` with interior padding is converted to +// `tfl.dilate` and `tfl.padv2`. func.func @pad_with_interior_padding( %arg0: tensor<2x3x!quant.uniform>, @@ -939,20 +939,20 @@ func.func @pad_with_interior_padding( // ----- -// Tests that a float stablehlo.pad is not converted to tfl.padv2. +// Tests that a float `stablehlo.pad` is not converted to `tfl.padv2`. -func.func @float_pad(%arg0: tensor<2x3xf32>, %arg1: tensor) -> tensor<4x5xf32> { +func.func @pad_float(%arg0: tensor<2x3xf32>, %arg1: tensor) -> tensor<4x5xf32> { %0 = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [0, 0] : (tensor<2x3xf32>, tensor) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } -// CHECK-LABEL: float_pad +// CHECK-LABEL: pad_float // CHECK-NOT: tfl.padv2 // CHECK: stablehlo.pad // ----- -// Tests that a quantized stablehlo.slice is converted to tfl.slice when stride -// is 1. +// Tests that a quantized `stablehlo.slice` is converted to +// `tfl.slice` when stride is 1. func.func @slice( %arg0: tensor<3x4x!quant.uniform> @@ -975,8 +975,8 @@ func.func @slice( // ----- -// Tests that a quantized stablehlo.slice is converted to tfl.strided_slice when -// stride is not 1. +// Tests that a quantized `stablehlo.slice` is converted to `tfl.strided_slice` +// when stride is not 1. func.func @strided_slice( %arg0: tensor<3x6x!quant.uniform> @@ -1003,9 +1003,9 @@ func.func @strided_slice( // ----- -// Tests that a float stablehlo.slice is not converted to tfl.slice. +// Tests that a float `stablehlo.slice` is not converted to `tfl.slice`. -func.func @float_slice(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> { +func.func @slice_float(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> { %0 = "stablehlo.slice"(%arg0) { start_indices = array, limit_indices = array, @@ -1013,15 +1013,15 @@ func.func @float_slice(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> { } : (tensor<3x4xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } -// CHECK-LABEL: float_slice +// CHECK-LABEL: slice_float // CHECK-NOT: tfl.slice // CHECK-NOT: tfl.strided_slice // CHECK: stablehlo.slice // ----- -// Tests that a quantized stablehlo.broadcast_in_dim is converted to -// tfl.broadcast_to. +// Tests that a quantized `stablehlo.broadcast_in_dim` is converted to +// `tfl.broadcast_to`. func.func @broadcast_in_dim( %arg0: tensor<1x2x!quant.uniform> @@ -1040,8 +1040,8 @@ func.func @broadcast_in_dim( // ----- -// Tests that a quantized stablehlo.broadcast_in_dim is converted to -// tfl.transpose and tfl.broadcast_to when broadcast_dimensions is not in +// Tests that a quantized `stablehlo.broadcast_in_dim` is converted to +// `tfl.transpose` and `tfl.broadcast_to` when `broadcast_dimensions` is not in // ascending order. func.func @broadcast_in_dim_with_transpose( @@ -1064,8 +1064,8 @@ func.func @broadcast_in_dim_with_transpose( // ----- -// Tests that a quantized stablehlo.broadcast_in_dim is converted to -// tfl.expand_dims and tfl.broadcast_to when input rank is smaller than output +// Tests that a quantized `stablehlo.broadcast_in_dim` is converted to +// tfl.expand_dims and `tfl.broadcast_to` when input rank is smaller than output // rank. func.func @broadcast_in_dim_with_expand( @@ -1088,9 +1088,10 @@ func.func @broadcast_in_dim_with_expand( // ----- -// Tests that a quantized stablehlo.broadcast_in_dim is converted to -// tfl.transpose, tfl.expand_dims and tfl.broadcast_to when broadcast_dimensions -// is not in ascending order and input rank is smaller than output rank. +// Tests that a quantized `stablehlo.broadcast_in_dim` is converted to +// `tfl.transpose`, `tfl.expand_dims` and `tfl.broadcast_to` when +// `broadcast_dimensions` is not in ascending order and input rank is smaller +// than output rank. func.func @broadcast_in_dim_with_transpose_and_expand( %arg0: tensor<2x3x4x!quant.uniform> @@ -1114,15 +1115,16 @@ func.func @broadcast_in_dim_with_transpose_and_expand( // ----- -// Tests that a float stablehlo.broadcast_in_dim is not converted to tfl.broadcast_to. +// Tests that a float `stablehlo.broadcast_in_dim` is not converted to +// `tfl.broadcast_to`. -func.func @float_broadcast_in_dim(%arg0: tensor<1x2xf32>) -> tensor<3x2xf32> { +func.func @broadcast_in_dim_float(%arg0: tensor<1x2xf32>) -> tensor<3x2xf32> { %0 = "stablehlo.broadcast_in_dim"(%arg0) { broadcast_dimensions = array } : (tensor<1x2xf32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> } -// CHECK-LABEL: float_broadcast_in_dim +// CHECK-LABEL: broadcast_in_dim_float // CHECK-NOT: tfl.broadcast_to // CHECK-NOT: tfl.transpose // CHECK-NOT: tfl.expand_dims @@ -1130,8 +1132,8 @@ func.func @float_broadcast_in_dim(%arg0: tensor<1x2xf32>) -> tensor<3x2xf32> { // ----- -// Test that a quantized stablehlo.reduce_window with max is converted to -// tfl.max_pool_2d. +// Tests that a quantized `stablehlo.reduce_window` with max is converted to +// `tfl.max_pool_2d`. func.func @reduce_window_with_max( %arg0: tensor<2x9x10x3x!quant.uniform>, @@ -1155,8 +1157,8 @@ func.func @reduce_window_with_max( // ----- -// Test that a quantized stablehlo.reduce_window with max whose rank is not 4 -// is not converted to tfl.max_pool_2d. +// Tests that a quantized `stablehlo.reduce_window `with max whose rank is not 4 +// is not converted to `tfl.max_pool_2d`. func.func @reduce_window_not_4d( %arg0: tensor<3x2x9x10x3x!quant.uniform>, @@ -1176,8 +1178,8 @@ func.func @reduce_window_not_4d( // ----- -// Test that a quantized stablehlo.reduce_window with max that takes multiple -// inputs is not converted to tfl.max_pool_2d. +// Tests that a quantized `stablehlo.reduce_window` with max that takes multiple +// inputs is not converted to `tfl.max_pool_2d`. func.func @reduce_window_not_binary( %arg0: tensor<3x2x9x10x3x!quant.uniform>, @@ -1200,10 +1202,10 @@ func.func @reduce_window_not_binary( // ----- -// Test that a float stablehlo.reduce_window with max is not converted to -// tfl.max_pool_2d. +// Tests that a float `stablehlo.reduce_window` with max is not converted to +// `tfl.max_pool_2d`. -func.func @float_reduce_window_with_max( +func.func @reduce_window_with_max_float( %arg0: tensor<2x9x10x3xf32>, %arg1: tensor ) -> tensor<2x4x3x3xf32> { @@ -1215,13 +1217,14 @@ func.func @float_reduce_window_with_max( return %0 : tensor<2x4x3x3xf32> } -// CHECK-LABEL: float_reduce_window_with_max +// CHECK-LABEL: reduce_window_with_max_float // CHECK: stablehlo.reduce_window // CHECK-NOT: tfl.max_pool_2d // ----- -// Test that a quantized stablehlo.dynamic_reshape is converted to tfl.reshape. +// Tests that a quantized `stablehlo.dynamic_reshape` is converted to +// `tfl.reshape`. func.func @dynamic_reshape( %arg0: tensor>, @@ -1242,20 +1245,21 @@ func.func @dynamic_reshape( // ----- -// Test that a float stablehlo.dynamic_reshape is not converted to tfl.reshape. +// Tests that a float `stablehlo.dynamic_reshape` is not converted to +// `tfl.reshape`. -func.func @float_dynamic_reshape(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor { +func.func @dynamic_reshape_float(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor { %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor return %0 : tensor } -// CHECK-LABEL: func @float_dynamic_reshape +// CHECK-LABEL: func @dynamic_reshape_float // CHECK: stablehlo.dynamic_reshape // CHECK-NOT: tfl.reshape // ----- -// Test that a quantized stablehlo.gather is converted to tfl.gather_nd. +// Tests that a quantized `stablehlo.gather` is converted to tfl.gather_nd. func.func @gather( %arg0: tensor<3x4x2x2x!quant.uniform>, @@ -1284,8 +1288,8 @@ func.func @gather( // ----- -// Test that a quantized stablehlo.gather with unsorted start_index_map is not -// converted to tfl.gather_nd (condition 1 is not satisfied). +// Tests that a quantized `stablehlo.gather` with unsorted start_index_map is +// not converted to `tfl.gather_nd` (condition 1 is not satisfied). func.func @gather_start_index_map_not_sorted( %arg0: tensor<3x4x2x2x!quant.uniform>, @@ -1313,7 +1317,7 @@ func.func @gather_start_index_map_not_sorted( // ----- -// Test that a quantized stablehlo.gather is not converted to tfl.gather_nd +// Tests that a quantized `stablehlo.gather` is not converted to tfl.gather_nd // when index_vector_dim is not the last dimension of start_indices (condition 2 // is not satisfied). @@ -1343,7 +1347,7 @@ func.func @gather_start_index_vector_dim_not_at_last( // ----- -// Test that a quantized stablehlo.gather is not converted to tfl.gather_nd +// Tests that a quantized `stablehlo.gather` is not converted to tfl.gather_nd // when offset_dims are not the last dimensions of the output (condition 3 is // not satisfied). @@ -1373,7 +1377,7 @@ func.func @gather_offset_dims_not_at_last( // ----- -// Test that a quantized stablehlo.gather is not converted to tfl.gather_nd +// Tests that a quantized `stablehlo.gather` is not converted to tfl.gather_nd // when shape of slice is not same with shape of offset (condition 4 is not // satisfied). @@ -1403,9 +1407,9 @@ func.func @gather_different_slice_and_offset( // ----- -// Test that a float stablehlo.gather is not converted to tfl.gather_nd. +// Tests that a float `stablehlo.gather` is not converted to `tfl.gather_nd`. -func.func @float_gather(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> { +func.func @gather_float(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> { %0 = "stablehlo.gather"(%arg0, %arg1) { dimension_numbers = #stablehlo.gather< offset_dims = [2, 3], @@ -1418,7 +1422,161 @@ func.func @float_gather(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2x3x2xi64>) -> return %0 : tensor<2x3x2x2xf32> } -// CHECK-LABEL: func @float_gather +// CHECK-LABEL: func @gather_float // CHECK: stablehlo.gather // CHECK-NOT: tfl.gather_nd // CHECK-NOT: tfl.gather + +// ----- + +// Tests that a quantized `stablehlo.dynamic_slice` is converted to `tfl.slice`. + +// CHECK-LABEL: func @dynamic_slice +// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4x!quant.uniform>, %[[ARG1:.+]]: tensor, %[[ARG2:.+]]: tensor +func.func @dynamic_slice( + %arg0: tensor<4x4x!quant.uniform>, + %arg1: tensor, + %arg2: tensor + ) -> tensor<2x1x!quant.uniform> { + %0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) { + slice_sizes = array + } : ( + tensor<4x4x!quant.uniform>, tensor, + tensor + ) -> tensor<2x1x!quant.uniform> + return %0 : tensor<2x1x!quant.uniform> +} + + +// CHECK-DAG: %[[SLICE_SIZE:.+]] = arith.constant dense<[2, 1]> : tensor<2xi64> +// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0> : tensor<1xi64> +// CHECK-DAG: %[[MAX1:.+]] = arith.constant dense<2> : tensor<1xi64> +// CHECK-DAG: %[[MAX2:.+]] = arith.constant dense<3> : tensor<1xi64> +// CHECK: %[[BITCAST1:.+]] = "tfl.bitcast"(%[[ARG1]]) : (tensor) -> tensor<1xi64> +// CHECK: %[[MIN1:.+]] = "tfl.minimum"(%[[BITCAST1]], %[[MAX1]]) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> +// CHECK: %[[BITCAST2:.+]] = "tfl.bitcast"(%[[ARG2]]) : (tensor) -> tensor<1xi64> +// CHECK: %[[MIN2:.+]] = "tfl.minimum"(%[[BITCAST2]], %[[MAX2]]) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> +// CHECK: %[[CONCAT:.+]] = "tfl.concatenation"(%[[MIN1]], %[[MIN2]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> +// CHECK: %[[MAX:.+]] = "tfl.maximum"(%[[CONCAT]], %[[ZERO]]) : (tensor<2xi64>, tensor<1xi64>) -> tensor<2xi64> +// CHECK: %[[SLICE:.+]] = "tfl.slice"(%[[ARG0]], %[[MAX]], %[[SLICE_SIZE]]) +// CHECK-SAME: (tensor<4x4x!quant.uniform>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x1x!quant.uniform> + +// ----- + +// Tests that a float `stablehlo.dynamic_slice` is not converted to `tfl.slice`. + +func.func @dynamic_slice_float(%arg0: tensor<4x4xf32>, %arg1: tensor, %arg2: tensor) -> tensor<2x1xf32> { + %0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) { + slice_sizes = array + } : (tensor<4x4xf32>, tensor, tensor) -> tensor<2x1xf32> + return %0 : tensor<2x1xf32> +} + +// CHECK-LABEL: func @dynamic_slice_float +// CHECK: stablehlo.dynamic_slice +// CHECK-NOT: tfl.bitcast +// CHECK-NOT: tfl.minimum +// CHECK-NOT: tfl.maximum +// CHECK-NOT: tfl.slice + +// ----- + +// Tests that `stablehlo.add` with both operands int8 UniformQuantizedType is +// properly converted into `tfl.add`. + +func.func @add(%arg0: tensor<1x3x!quant.uniform>, %arg1: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> { + %0 = stablehlo.add %arg0, %arg1 : (tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// CHECK-LABEL: func @add +// CHECK: %[[ADD:.+]] = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[ADD]] + +// ----- + +// Tests that `stablehlo.add` with int32 UniformQuantizedPerAxisTypes is +// not converted. + +func.func @add_i32(%arg0: tensor<1x3x!quant.uniform>, %arg1: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> { + %0 = stablehlo.add %arg0, %arg1 : (tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// CHECK-LABEL: func @add_i32 +// CHECK: stablehlo.add +// CHECK-NOT: tfl.add + +// ----- + +// Tests that a quantized `stablehlo.constant` is converted into `tfl.qconst`. + +// CHECK-LABEL: func @quantized_constant +func.func @quantized_constant() -> tensor<1x2x4x5x!quant.uniform> { + %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> + return %0 : tensor<1x2x4x5x!quant.uniform> +} + +// CHECK: %[[QCONST:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} +// CHECK-SAME: () -> tensor<1x2x4x5x!quant.uniform> +// CHECK: return %[[QCONST]] + +// ----- + +// Tests that a float `stablehlo.constant` is not converted into `tfl.qconst`. + +// CHECK-LABEL: func @float_constant +func.func @float_constant() -> tensor<1x2x4x5xf32> { + %0 = stablehlo.constant() {value = dense<1.0> : tensor<1x2x4x5xf32>} : () -> tensor<1x2x4x5xf32> + return %0 : tensor<1x2x4x5xf32> +} + +// CHECK: stablehlo.constant +// CHECK-NOT: tfl.pseudo_qconst +// CHECK-NOT: tfl.pseudo_const +// CHECK-NOT: arith.constant + +// ----- + +// Tests that a hybrid quantized dot_general is splitted into dequantize and float +// dot_general. + +// CHECK-LABEL: func @dot_general_hybrid +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x3x4xf32> +func.func @dot_general_hybrid(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x5xf32> { + %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> + %1 = "stablehlo.dot_general"(%arg0, %0) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2]>, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x2x3x4xf32>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5xf32> + return %1 : tensor<1x2x3x5xf32> +} + +// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} +// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x4x5xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG0]], %[[DQ]], batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2], precision = [DEFAULT, DEFAULT] : (tensor<1x2x3x4xf32>, tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Tests that a hybrid quantized convolution is splitted into dequantize and +// float convolution. + +// CHECK-LABEL: func @convolution_hybrid +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x3x4xf32> +func.func @convolution_hybrid(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x2xf32> { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2xf32> + return %1 : tensor<1x3x3x2xf32> +} + +// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x3x4x2x!quant.uniform>, value = dense<3> : tensor<3x3x4x2xi8>} +// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<3x3x4x2x!quant.uniform>) -> tensor<3x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG0]], %[[DQ]]) +// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} +// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<3x3x4x2xf32>) -> tensor<1x3x3x2xf32> +// CHECK: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc new file mode 100644 index 00000000000000..0dc354f998d246 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { + +namespace { + +// This file is generated from `passes.td` and provides the implementation base +// class. +#define GEN_PASS_DEF_COMPOSITELOWERINGPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" + +class CompositeLoweringPass + : public impl::CompositeLoweringPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CompositeLoweringPass); + + void runOnOperation() override; +}; + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_composite_lowering.inc" + +void CompositeLoweringPass::runOnOperation() { + MLIRContext& context = getContext(); + RewritePatternSet patterns(&getContext()); + + populateWithGenerated(patterns); + + ConversionTarget target(context); + target.addLegalDialect(); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + getOperation().emitError("Composite lowering pass failed."); + signalPassFailure(); + } +} + +} // namespace + +// Creates an instance of the pass. +std::unique_ptr> CreateCompositeLoweringPass() { + return std::make_unique(); +} + +// Registers the pass implementation +static PassRegistration pass; + +} // namespace odml +} // namespace mlir diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.h similarity index 57% rename from third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.h index 2ac358b4ee56c5..0bb758ad9f154b 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The OpenXLA Authors. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_LOWERING_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_LOWERING_PASS_H_ -namespace stream_executor::gpu { -enum struct GpuSemaphoreState { Hold, Release, TimedOut }; -namespace delay_kernel { -void* kernel(); // returns a pointer to a CUDA C++ device function -} // namespace delay_kernel -} // namespace stream_executor::gpu +namespace mlir { +namespace odml { -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ +std::unique_ptr CreateCompositeLoweringPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_LOWERING_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td new file mode 100644 index 00000000000000..1b62b6fcc4aeae --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -0,0 +1,28 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Pattern definition file for direct lowering of mhlo composites to tflite ops. + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mhlo/IR/hlo_ops.td" +include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" + + +def LegalizeHardSwishComposite: Pat< + (MHLO_CompositeOp:$old_value + (variadic $input), + ConstantStrAttr, $_, $_, $_), + (TFL_HardSwishOp $input)>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc index 066bc83ad90217..847738e5cc7cbe 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc @@ -177,8 +177,7 @@ class FoldBroadcastInDimBeforeBinaryElementwiseOp // When the operand other than the broadcast op is not a const op, we // should not fold broadcast op. auto binary_op_const_operand = - lhs_bcast_op ? rhs.template getDefiningOp() - : lhs.template getDefiningOp(); + (lhs_bcast_op ? rhs : lhs).template getDefiningOp(); if (!binary_op_const_operand) return failure(); auto bcast_op = lhs_bcast_op ? lhs_bcast_op : rhs_bcast_op; auto const_op = diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc new file mode 100644 index 00000000000000..a35f5ba324e3f4 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc @@ -0,0 +1,137 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" + +#define DEBUG_TYPE "composite-to-custom" + +namespace mlir { +namespace odml { + +#define GEN_PASS_DEF_LEGALIZECOMPOSITETOCUSTOMOPPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" + +namespace { +bool IsSupportedComposite(::mlir::stablehlo::CompositeOp op) { + // List of supported composites to represent using CustomOp. + return llvm::is_contained( + {"odml.update_kv_cache", "odml.scaled_dot_product_attention"}, + op.getName()); +} + +TFL::ConstBytesAttr CustomOption(OpBuilder* builder, + const std::string& content) { + return TFL::ConstBytesAttr::get(builder->getContext(), + StringRef(content.data(), content.size())); +} + +LogicalResult BuildOption(flexbuffers::Builder* fbb, Operation* op, + NamedAttribute pair) { + const char* key = pair.getName().data(); + const auto attr = pair.getValue(); + + if (attr.isa<::mlir::IntegerAttr>()) { + fbb->Int(key, attr.dyn_cast().getInt()); + return success(); + } + + if (attr.isa<::mlir::FloatAttr>()) { + fbb->Double(key, attr.dyn_cast().getValueAsDouble()); + return success(); + } + + return op->emitWarning("serialization not supported for : ") << key; +} + +TFL::CustomOp BuildCustomOp(stablehlo::CompositeOp composite, + const std::string& custom_option_buffer) { + OpBuilder builder(composite->getContext()); + builder.setInsertionPoint(composite); + return builder.create( + composite->getLoc(), composite->getResultTypes(), + composite->getOperands(), composite.getName(), + CustomOption(&builder, custom_option_buffer)); +} + +} // namespace + +// Legalize stablehlo::CompositeOp to TFL::CustomOp for runtime-supported +// composites. See `IsSupportedComposite` for list of supported ops. +// +// Example: +// %0 = stablehlo.composite "odml.some_op" { +// composite_attrs = {}, +// version = 0 : i32 +// } +// ==> +// %0 = tfl.custom() { +// custom_code = "odml.some_op", +// custom_option = #tfl +// } +struct LegalizeCompositeToCustomOpPass + : public impl::LegalizeCompositeToCustomOpPassBase< + LegalizeCompositeToCustomOpPass> { + using LegalizeCompositeToCustomOpPassBase:: + LegalizeCompositeToCustomOpPassBase; + + void runOnOperation() override { + func::FuncOp fn = getOperation(); + fn.walk([&](Operation* op) { + // Process only StableHLO composite ops. + auto composite = llvm::dyn_cast(op); + if (!composite || !IsSupportedComposite(composite)) return; + + // Build flexbuffer options. + std::string custom_option_buffer; + auto fbb = std::make_unique(); + size_t map_start = fbb->StartMap(); + for (const NamedAttribute& pair : composite.getCompositeAttributes()) { + // Allows skipping unsupported attributes, will warn. + (void)BuildOption(fbb.get(), op, pair); + } + fbb->EndMap(map_start); + fbb->Finish(); + custom_option_buffer.assign(fbb->GetBuffer().begin(), + fbb->GetBuffer().end()); + + // Build TFL custom op, replace composite with custom op. + TFL::CustomOp tfl_custom_op = + BuildCustomOp(composite, custom_option_buffer); + composite->replaceAllUsesWith(tfl_custom_op); + composite->erase(); + }); + } +}; + +static PassRegistration pass; + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc new file mode 100644 index 00000000000000..4cfb0e04e96af4 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc @@ -0,0 +1,110 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" + +namespace mlir { +namespace odml { + +#define GEN_PASS_DEF_LEGALIZESTABLEHLOCUSTOMCALLTOCOMPOSITEPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" + +struct ReplaceCustomCallWithComposite final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + explicit ReplaceCustomCallWithComposite(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op, + PatternRewriter &rewriter) const override { + auto backendConfig = + op->getAttr("composite.backend_config").dyn_cast(); + if (!backendConfig) + return op->emitError( + "custom_call has no 'composite.backend_config' attribute or the " + "attribute is not a dictionary"); + + auto name = backendConfig.get("name").dyn_cast(); + if (!name) + return op->emitError( + "backend_config has no 'name' key or the name value is not a string"); + + auto attrs = backendConfig.get("attributes").dyn_cast(); + if (!attrs) + return op->emitError( + "backend_config has no 'attributes' key or the attributes value is " + "not a dictionary"); + + auto calledComputations = op.getCalledComputations(); + if (!calledComputations || calledComputations.size() != 1) + return op->emitError("expected exactly one called_computation"); + + auto decomposition = calledComputations[0].cast(); + + auto composite = rewriter.create( + op.getLoc(), op.getResultTypes(), op.getOperands(), name.str(), attrs, + decomposition.getValue()); + rewriter.replaceOp(op, composite.getResults()); + return success(); + } +}; + +struct LegalizeStablehloCustomCallToCompositePass + : public impl::LegalizeStablehloCustomCallToCompositePassBase< + LegalizeStablehloCustomCallToCompositePass> { + using LegalizeStablehloCustomCallToCompositePassBase:: + LegalizeStablehloCustomCallToCompositePassBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addDynamicallyLegalOp( + [&](mlir::stablehlo::CustomCallOp op) { + return op.getCallTargetName() != "stablehlo.composite"; + }); + + RewritePatternSet patterns(context); + patterns.add(context); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +static PassRegistration + pass_shlo_sc2c; + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h index 8df2d3503f3632..49e8b673f63374 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h @@ -59,6 +59,9 @@ std::unique_ptr> CreateUnfoldSplatConstantPass(); // Create a pass that legalizes MHLO to TFLite dialect. std::unique_ptr> CreateLegalizeHloToTfLitePass(); +// Creates a pass that lowers stablehlo composite ops to tflite ops. +std::unique_ptr> CreateCompositeLoweringPass(); + // Adds the HLO to TF rewrite patterns to the specified pattern list. void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, MLIRContext* context); @@ -67,8 +70,7 @@ void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, void PopulateLegalizeHloToTFLitePatterns(RewritePatternSet* patterns, MLIRContext* context); -#define GEN_PASS_DECL_LEGALIZESTABLEHLOTOVHLOPASS -#define GEN_PASS_DECL_LEGALIZEVHLOTOSTABLEHLOPASS +#define GEN_PASS_DECL #define GEN_PASS_REGISTRATION #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td index 002990601a9efb..a535d3aa867c80 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td @@ -95,8 +95,22 @@ def LegalizeVhloToStablehloPass : Pass<"vhlo-legalize-stablehlo", "ModuleOp"> { let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; } +def LegalizeCompositeToCustomOpPass : Pass<"stablehlo-composite-legalize-tfl-custom", "func::FuncOp"> { + let summary = "Legalize supported StableHLO CompositeOps to TFL CustomOp"; + let dependentDialects = ["TFL::TensorFlowLiteDialect"]; +} +def LegalizeStablehloCustomCallToCompositePass : Pass<"stablehlo-custom-call-legalize-composite", "ModuleOp"> { + let summary = "Legalize StableHLO custom call ops where the call target is 'stablehlo.composite' to composite ops."; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} def UnfoldSplatConstantPass : Pass<"unfold-splat-constant-pass", "ModuleOp"> { let summary = "Replaces a splat constant tensor with a BroadcastInDim op."; let constructor = "mlir::odml::CreateUnfoldSplatConstantPass()"; } + +def CompositeLoweringPass : Pass<"composite-lowering", "ModuleOp"> { + let summary = "Lowers mhlo composites directly to tflite ops (when possible)."; + let dependentDialects = ["mlir::mhlo::MhloDialect", "TFL::TensorFlowLiteDialect"]; + let constructor = "mlir::odml::CreateCompositeLoweringPass()"; +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc index 033ec78751e6b6..06754ea72b580c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project @@ -70,6 +71,9 @@ class SmuggleDisallowedOpsPass StringRef getDescription() const final { return "Smuggle disallowed ops via stablehlo.custom_calls"; } + void getDependentDialects(DialectRegistry& registry) const final { + registry.insert(); + } void runOnOperation() override { RewritePatternSet patterns(&getContext()); @@ -77,7 +81,7 @@ class SmuggleDisallowedOpsPass patterns.add>(&getContext()); ConversionTarget target(getContext()); - target.addIllegalDialect(); + target.addIllegalOp(); target.addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc deleted file mode 100644 index b120ca89c290d4..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc +++ /dev/null @@ -1,279 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h" - -#include -#include -#include - -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "flatbuffers/flexbuffers.h" // from @flatbuffers -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" - -namespace mlir { -namespace odml { - -class StablehloToTflPass - : public mlir::PassWrapper> { - public: - explicit StablehloToTflPass() : PassWrapper() {} - StringRef getArgument() const final { return "stablehlo-tfl"; } - StringRef getDescription() const final { - return "This pass will legalize StableHLO Ops to TFLite custom Ops."; - } - - private: - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } - inline TFL::ConstBytesAttr CustomOption(OpBuilder* builder, - const std::string& content) { - return TFL::ConstBytesAttr::get(builder->getContext(), - StringRef(content.data(), content.size())); - } - - void AddIntegerArray(flexbuffers::Builder* fbb, - ::llvm::ArrayRef vec) { - auto start_input_dim = fbb->StartVector(); - for (auto int_value : vec) { - fbb->Add(int_value); - } - fbb->EndVector(start_input_dim, /*typed=*/false, /*fixed=*/false); - } -}; - -void StablehloToTflPass::runOnOperation() { - func::FuncOp fn = getOperation(); - OpBuilder builder(fn.getContext()); - fn.walk([&](Operation* op) { - // Process only StableHLO ops. - if (op->getDialect()->getNamespace() != "stablehlo") return; - - // Build options. - std::string custom_option_buffer; - auto fbb = std::make_unique(); - size_t map_start = fbb->StartMap(); - for (auto pair : op->getAttrDictionary().getValue()) { - const char* key = pair.getName().data(); - const auto attr = pair.getValue(); - - if (attr.isa<::mlir::IntegerAttr>()) { - fbb->Int(key, attr.dyn_cast().getInt()); - continue; - } - - if (attr.isa<::mlir::FloatAttr>()) { - fbb->Double(key, attr.dyn_cast().getValueAsDouble()); - continue; - } - - if (attr.isa<::mlir::ElementsAttr>()) { - auto start = fbb->StartVector(key); - auto array_attr = attr.dyn_cast(); - const auto ftype = array_attr.getElementType(); - if (ftype.isInteger(16) || ftype.isInteger(32) || ftype.isInteger(64) || - ftype.isInteger(128) || ftype.isInteger(1)) { - for (auto value : array_attr.getValues()) { - auto int_value = - value.dyn_cast_or_null().getInt(); - fbb->Add(int_value); - } - } else if (ftype.isF32() || ftype.isF64() || ftype.isF128()) { - for (auto value : array_attr.getValues()) { - auto double_value = - value.dyn_cast_or_null().getValueAsDouble(); - fbb->Add(double_value); - } - } else { - emitWarning(op->getLoc(), "serialization of ElementsAttr for ") - << key << " only supports Integer and Float."; - } - fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); - continue; - } - - if (attr.isa<::mlir::DenseI64ArrayAttr>()) { - auto array_attr = attr.dyn_cast(); - auto start = fbb->StartVector(key); - for (auto int_value : array_attr.asArrayRef()) { - fbb->Add(int_value); - } - fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); - continue; - } - - if (attr.isa<::mlir::DenseBoolArrayAttr>()) { - auto array_attr = attr.dyn_cast(); - auto start = fbb->StartVector(key); - for (auto bool_value : array_attr.asArrayRef()) { - fbb->Add(bool_value); - } - fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); - continue; - } - - if (attr.isa<::mlir::StringAttr>()) { - fbb->String(key, attr.dyn_cast().data()); - continue; - } - - if (attr.isa<::mlir::ArrayAttr>()) { - auto start = fbb->StartVector(key); - auto array_attr = attr.dyn_cast(); - if (array_attr.size() > 1 && !array_attr[0].isa() && - !array_attr[0].isa()) { - emitWarning(op->getLoc(), "serialization of ArrayAttr for ") - << key << " only supports Strings."; - continue; - } - for (auto value : array_attr) { - if (value.isa()) { - auto string_value = - mlir::stablehlo::stringifyPrecision( - value.cast().getValue()) - .data(); - fbb->Add(string_value); - } else { - auto string_value = - value.dyn_cast_or_null().data(); - fbb->Add(string_value); - } - } - fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); - continue; - } - - if (attr.isa<::mlir::stablehlo::ConvDimensionNumbersAttr>()) { - auto dimension_attr = - attr.dyn_cast<::mlir::stablehlo::ConvDimensionNumbersAttr>(); - auto start = fbb->StartVector(key); - fbb->Add(dimension_attr.getInputBatchDimension()); - fbb->Add(dimension_attr.getInputFeatureDimension()); - AddIntegerArray(fbb.get(), dimension_attr.getInputSpatialDimensions()); - fbb->Add(dimension_attr.getKernelInputFeatureDimension()); - fbb->Add(dimension_attr.getKernelOutputFeatureDimension()); - AddIntegerArray(fbb.get(), dimension_attr.getKernelSpatialDimensions()); - fbb->Add(dimension_attr.getOutputBatchDimension()); - fbb->Add(dimension_attr.getOutputFeatureDimension()); - AddIntegerArray(fbb.get(), dimension_attr.getOutputSpatialDimensions()); - fbb->EndVector(start, /*typed=*/false, /*fixed=*/false); - continue; - } - - if (attr.isa<::mlir::stablehlo::GatherDimensionNumbersAttr>()) { - auto dimension_attr = - attr.dyn_cast<::mlir::stablehlo::GatherDimensionNumbersAttr>(); - auto start = fbb->StartVector(key); - AddIntegerArray(fbb.get(), dimension_attr.getOffsetDims()); - AddIntegerArray(fbb.get(), dimension_attr.getCollapsedSliceDims()); - AddIntegerArray(fbb.get(), dimension_attr.getStartIndexMap()); - fbb->Add(dimension_attr.getIndexVectorDim()); - fbb->EndVector(start, /*typed=*/false, /*fixed=*/false); - continue; - } - - if (attr.isa<::mlir::stablehlo::ScatterDimensionNumbersAttr>()) { - auto dimension_attr = - attr.dyn_cast<::mlir::stablehlo::ScatterDimensionNumbersAttr>(); - auto start = fbb->StartVector(key); - AddIntegerArray(fbb.get(), dimension_attr.getUpdateWindowDims()); - AddIntegerArray(fbb.get(), dimension_attr.getInsertedWindowDims()); - AddIntegerArray(fbb.get(), - dimension_attr.getScatterDimsToOperandDims()); - fbb->Add(dimension_attr.getIndexVectorDim()); - fbb->EndVector(start, /*typed=*/false, /*fixed=*/false); - continue; - } - - if (attr.isa<::mlir::stablehlo::DotDimensionNumbersAttr>()) { - auto dimension_attr = - attr.dyn_cast<::mlir::stablehlo::DotDimensionNumbersAttr>(); - auto start = fbb->StartVector(key); - AddIntegerArray(fbb.get(), dimension_attr.getLhsBatchingDimensions()); - AddIntegerArray(fbb.get(), dimension_attr.getRhsBatchingDimensions()); - AddIntegerArray(fbb.get(), - dimension_attr.getLhsContractingDimensions()); - AddIntegerArray(fbb.get(), - dimension_attr.getRhsContractingDimensions()); - fbb->EndVector(start, /*typed=*/false, /*fixed=*/false); - continue; - } - - if (attr.isa<::mlir::stablehlo::ComparisonDirectionAttr>()) { - auto string_value = - mlir::stablehlo::stringifyComparisonDirection( - attr.cast() - .getValue()) - .str(); - fbb->String(key, string_value); - continue; - } - - if (attr.isa<::mlir::stablehlo::ComparisonTypeAttr>()) { - auto string_value = - mlir::stablehlo::stringifyComparisonType( - attr.cast().getValue()) - .str(); - fbb->String(key, string_value); - continue; - } - - // default - emitWarning(op->getLoc(), "serialization not supported for : ") << key; - } - fbb->EndMap(map_start); - fbb->Finish(); - custom_option_buffer.assign(fbb->GetBuffer().begin(), - fbb->GetBuffer().end()); - - // Build custom op. - builder.setInsertionPoint(op); - auto tfl_custom_op = builder.create( - op->getLoc(), op->getResultTypes(), op->getOperands(), - op->getName().getStringRef(), - CustomOption(&builder, custom_option_buffer)); - op->replaceAllUsesWith(tfl_custom_op); - op->erase(); - }); -} -std::unique_ptr> CreateStablehloToTflPass() { - return std::make_unique(); -} - -static PassRegistration pass; - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h deleted file mode 100644 index 9445b770f10562..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_TFL_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_TFL_PASS_H_ - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project - -namespace mlir { -namespace odml { - -// Creates a pass which transforms StableHLO Ops to TFL Ops. -std::unique_ptr> CreateStablehloToTflPass(); - -} // namespace odml -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_TFL_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc index e0a1dbc3bf9445..fcacfcf4984db1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/rewriters.h" +#include "xla/mlir_hlo/mhlo/utils/type_conversion.h" namespace mlir { namespace odml { @@ -104,8 +105,8 @@ void TFToMhloPass::runOnOperation() { mhlo::Tf2XlaTypeConverter converter; mhlo::PopulateLegalizeTfWithTf2XlaPatterns( "XLA_CPU_JIT", patterns, context, converter, /*prefer_tf2xla=*/false); - chlo::populateDecomposeChloPatterns(context, &patterns); - chlo::populateChloBroadcastingPatterns(context, &patterns); + stablehlo::StablehloToHloTypeConverter hlo_converter; + chlo::populateChloToHloPatterns(context, &hlo_converter, &patterns); chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context); ConversionTarget target(*context); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index aaa1236a6d9470..8fed8f3f01ed54 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -62,6 +62,7 @@ using ::mlir::quant::CreateI32F32UniformQuantizedType; using ::mlir::quant::CreateI8F32UniformQuantizedPerAxisType; using ::mlir::quant::CreateI8F32UniformQuantizedType; using ::mlir::quant::FindUserOfType; +using ::mlir::quant::GetElementType; using ::mlir::quant::IsI32F32UniformQuantizedPerAxisType; using ::mlir::quant::IsI32F32UniformQuantizedType; using ::mlir::quant::IsI8F32UniformQuantizedPerAxisType; @@ -142,10 +143,7 @@ TFL::QConstOp CreateTransposedTflConstOpForFilter( Type new_filter_quantized_type; if (is_per_channel) { - auto filter_quantized_type = filter_constant_op.getResult() - .getType() - .cast() - .getElementType() + auto filter_quantized_type = GetElementType(filter_constant_op.getResult()) .cast(); new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType( filter_constant_op->getLoc(), *rewriter.getContext(), @@ -153,10 +151,7 @@ TFL::QConstOp CreateTransposedTflConstOpForFilter( filter_quantized_type.getZeroPoints(), /*quantization_dimension=*/0, /*narrow_range=*/true); } else { - auto filter_quantized_type = filter_constant_op.getResult() - .getType() - .cast() - .getElementType() + auto filter_quantized_type = GetElementType(filter_constant_op.getResult()) .cast(); new_filter_quantized_type = CreateI8F32UniformQuantizedType( filter_constant_op->getLoc(), *rewriter.getContext(), @@ -224,9 +219,7 @@ TFL::QConstOp CreateTflConstOpForDummyBias( Type bias_quantized_type; if (is_per_channel) { const auto filter_quantized_element_type = - filter_const_op.getResult() - .getType() - .getElementType() + GetElementType(filter_const_op.getResult()) .cast(); // The storage type is i32 for bias, which is the precision used for @@ -238,9 +231,7 @@ TFL::QConstOp CreateTflConstOpForDummyBias( /*quantization_dimension=*/0); } else { const auto filter_quantized_element_type = - filter_const_op.getResult() - .getType() - .getElementType() + GetElementType(filter_const_op.getResult()) .cast(); // The storage type is i32 for bias, which is the precision used for @@ -277,8 +268,9 @@ arith::ConstantOp CreateI32ShapeConstantOp(const TensorType op_type, } // Returns the desired qi8 per-tensor quantized output type for a given gemm op. -Type GetOutputType(Operation* op, MLIRContext& ctx, const bool has_i32_output, - const bool fuse_bias_constant) { +Type GetQuantizedOutputType(Operation* op, PatternRewriter& rewriter, + const bool has_i32_output, + const bool fuse_bias_constant) { Operation* uniform_quantize_op; if (!has_i32_output) return op->getResult(0).getType(); if (fuse_bias_constant) { @@ -289,17 +281,15 @@ Type GetOutputType(Operation* op, MLIRContext& ctx, const bool has_i32_output, } // StableHLO Quantizer outputs an i32 type. Rewrite to i8 type result // to meet TFLite op requirement. - auto result_quantized_type = uniform_quantize_op->getResult(0) - .getType() - .cast() - .getElementType() + auto result_quantized_type = GetElementType(uniform_quantize_op->getResult(0)) .cast(); auto new_result_quantized_type = CreateI8F32UniformQuantizedType( - uniform_quantize_op->getLoc(), ctx, result_quantized_type.getScale(), - result_quantized_type.getZeroPoint()); + uniform_quantize_op->getLoc(), *rewriter.getContext(), + result_quantized_type.getScale(), result_quantized_type.getZeroPoint()); // Omit any bias and requantize ops as `tfl.{gemm_op}` outputs a // fused `qi8` type. - FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); + rewriter.replaceAllUsesWith(uniform_quantize_op->getResult(0), + op->getResult(0)); return op->getResult(0).getType().cast().clone( new_result_quantized_type); } @@ -315,8 +305,7 @@ class RewriteUniformQuantizeOp // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/quantize.cc#L105). LogicalResult match(stablehlo::UniformQuantizeOp op) const override { - const Type input_element_type = - op.getOperand().getType().cast().getElementType(); + const Type input_element_type = GetElementType(op.getOperand()); if (!(input_element_type.isa() || IsI32F32UniformQuantizedType(input_element_type) || IsI32F32UniformQuantizedPerAxisType(input_element_type))) { @@ -328,10 +317,7 @@ class RewriteUniformQuantizeOp // Output type of `UniformQuantizeOp` is guaranteed to be a quantized // tensor with integer storage type. - const auto output_storage_type = op.getResult() - .getType() - .cast() - .getElementType() + const auto output_storage_type = GetElementType(op.getResult()) .cast() .getStorageType() .cast(); @@ -363,10 +349,7 @@ class RewriteUniformDequantizeOp // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/dequantize.cc#L52). LogicalResult match(stablehlo::UniformDequantizeOp op) const override { - const auto input_storage_type = op.getOperand() - .getType() - .cast() - .getElementType() + const auto input_storage_type = GetElementType(op.getOperand()) .cast() .getStorageType() .cast(); @@ -377,11 +360,8 @@ class RewriteUniformDequantizeOp } // Output type is guaranteed to be a float tensor for a valid StableHLO. - const auto output_element_type = op.getResult() - .getType() - .cast() - .getElementType() - .cast(); + const auto output_element_type = + GetElementType(op.getResult()).cast(); if (!output_element_type.isa()) { LLVM_DEBUG(llvm::dbgs() << "Uniform dequantize op's output element type " "should be f32. Got: " @@ -448,8 +428,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp op.getDotDimensionNumbers(); const bool is_batch_matmul = !dot_dimension_nums.getLhsBatchingDimensions().empty(); - const Type elem_type = - op.getResult().getType().cast().getElementType(); + const Type elem_type = GetElementType(op.getResult()); const bool has_i32_output = IsI32F32UniformQuantizedType(elem_type) || IsI32F32UniformQuantizedPerAxisType(elem_type); @@ -479,8 +458,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const override { - const Type output_type = - op.getResult().getType().cast().getElementType(); + const Type output_type = GetElementType(op.getResult()); const bool has_i32_output = IsI32F32UniformQuantizedType(output_type) || IsI32F32UniformQuantizedPerAxisType(output_type); @@ -656,8 +634,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp static LogicalResult MatchOutput(const Value output, const bool has_i32_output, const bool is_batch_matmul) { - const Type output_element_type = - output.getType().cast().getElementType(); + const Type output_element_type = GetElementType(output); if (has_i32_output) { if (is_batch_matmul && !IsI32F32UniformQuantizedType(output_element_type)) { @@ -760,11 +737,8 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp TFL::QConstOp filter_constant_op = CreateTflConstOpForFilter( rhs_value.getDefiningOp(), rewriter, /*is_per_channel=*/true); - const double input_scale = lhs_value.getType() - .cast() - .getElementType() - .cast() - .getScale(); + const double input_scale = + GetElementType(lhs_value).cast().getScale(); TFL::QConstOp bias_tfl_op; bool fuse_bias_constant = FindUserOfType(op) && has_i32_output; @@ -800,16 +774,10 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp Operation* add_op = FindUserOfType(op); uniform_quantize_op = FindUserOfType(add_op); const auto filter_quantized_type = - op->getOperand(1) - .getType() - .cast() - .getElementType() + GetElementType(op->getOperand(1)) .cast(); const SmallVector bias_scales = GetBiasScales( - /*input_scale=*/op->getOperand(0) - .getType() - .cast() - .getElementType() + /*input_scale=*/GetElementType(op->getOperand(0)) .cast() .getScale(), /*filter_scales=*/filter_quantized_type.getScales()); @@ -821,10 +789,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp const auto bias_quantized_type = CreateI32F32UniformQuantizedPerAxisType( op->getLoc(), *op->getContext(), std::move(bias_scales), - op->getResult(0) - .getType() - .cast() - .getElementType() + GetElementType(op->getResult(0)) .cast() .getZeroPoints(), /*quantization_dimension=*/0); @@ -841,11 +806,9 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp uniform_quantize_op = FindUserOfType(op); } - const auto result_quantized_type = uniform_quantize_op->getResult(0) - .getType() - .cast() - .getElementType() - .cast(); + const auto result_quantized_type = + GetElementType(uniform_quantize_op->getResult(0)) + .cast(); const auto new_result_quantized_type = CreateI8F32UniformQuantizedType( uniform_quantize_op->getLoc(), *rewriter.getContext(), result_quantized_type.getScale(), @@ -856,8 +819,8 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp // fused `qi8` type. FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); } else { - output_type = GetOutputType(op, *rewriter.getContext(), has_i32_output, - fuse_bias_constant); + output_type = GetQuantizedOutputType(op, rewriter, has_i32_output, + fuse_bias_constant); } return output_type; } @@ -898,8 +861,8 @@ class RewriteQuantizedConvolutionOp public: using OpRewritePattern::OpRewritePattern; LogicalResult match(stablehlo::ConvolutionOp op) const override { - const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType( - op.getResult().getType().cast().getElementType()); + const bool has_i32_output = + IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); const bool fuse_bias_constant = FindUserOfType(op) && has_i32_output; stablehlo::ConvDimensionNumbersAttr dimension_numbers = @@ -965,8 +928,8 @@ class RewriteQuantizedConvolutionOp void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const override { - const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType( - op.getResult().getType().cast().getElementType()); + const bool has_i32_output = + IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); stablehlo::ConvDimensionNumbersAttr dimension_numbers = op.getDimensionNumbers(); @@ -993,8 +956,8 @@ class RewriteQuantizedConvolutionOp input_value = pad_op.getResult(); } - const Type output_type = GetOutputType(op, *rewriter.getContext(), - has_i32_output, fuse_bias_constant); + const Type output_type = GetQuantizedOutputType( + op, rewriter, has_i32_output, fuse_bias_constant); const auto [stride_h, stride_w] = GetStrides(op); const auto [dilation_h_factor, dilation_w_factor] = GetDilationFactors(op); if (is_depthwise) { @@ -1110,8 +1073,7 @@ class RewriteQuantizedConvolutionOp } static LogicalResult MatchOutput(Value output) { - const Type output_element_type = - output.getType().cast().getElementType(); + const Type output_element_type = GetElementType(output); if (!IsI32F32UniformQuantizedPerAxisType(output_element_type) && !IsI8F32UniformQuantizedType(output_element_type)) { LLVM_DEBUG( @@ -1290,7 +1252,7 @@ class RewriteQuantizedConvolutionOp // output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) auto get_output_dim_for_same_padding = [](int64_t input_dim, int64_t stride_dim) -> int64_t { - return std::ceil(input_dim / stride_dim); + return std::ceil(input_dim / static_cast(stride_dim)); }; return output_height == get_output_dim_for_same_padding(input_height, stride_height) && @@ -1397,10 +1359,7 @@ class RewriteQuantizedConvolutionOp Value filter_value = op.getOperand(1); Operation* filter_op = filter_value.getDefiningOp(); auto filter_uniform_quantized_type = - filter_value.getType() - .cast() - .getElementType() - .cast(); + GetElementType(filter_value).cast(); auto filter_constant_value_attr = cast( cast(filter_value.getDefiningOp()).getValue()); const DenseIntElementsAttr new_filter_value_attr = @@ -1440,10 +1399,7 @@ class RewriteQuantizedConvolutionOp const SmallVector bias_shape, const bool has_i32_output, const bool fuse_bias_constant) const { const SmallVector bias_scales = GetBiasScales( - /*input_scale=*/op.getOperand(0) - .getType() - .cast() - .getElementType() + /*input_scale=*/GetElementType(op.getOperand(0)) .cast() .getScale(), /*filter_scales=*/new_filter_quantized_type.getScales()); @@ -2032,20 +1988,186 @@ class RewriteQuantizedGatherOp : public OpRewritePattern { } }; +// Rewrites quantized stablehlo.dynamic_slice to tfl.slice. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. +class RewriteQuantizedDynamicSliceOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(stablehlo::DynamicSliceOp op) const override { + if (!IsQuantizedTensorType(op.getOperand().getType()) || + !IsQuantizedTensorType(op.getResult().getType())) { + return failure(); + } + + return success(quant::HasStaticShape(op.getOperand())); + } + + void rewrite(stablehlo::DynamicSliceOp op, + PatternRewriter& rewriter) const override { + Type output = op.getResult().getType(); + Value input = op.getOperand(); + TensorType operand_type = input.getType().cast(); + ArrayRef operand_shape = operand_type.getShape(); + const int64_t rank = operand_type.getRank(); + const Type i64_type = rewriter.getI64Type(); + + ArrayRef slice_sizes = op.getSliceSizes(); + TensorType single_element_type = + operand_type.cloneWith({static_cast(1)}, i64_type); + + SmallVector start_indices(rank); + for (auto [i, start_index] : llvm::enumerate(op.getStartIndices())) { + // Start indices should be casted from tensor to tensor<1xi64>. + auto cast = rewriter.create( + op->getLoc(), single_element_type, start_index); + int64_t upper_limit_idx = operand_shape[i] - slice_sizes[i]; + auto upper_limit_attr = + DenseIntElementsAttr::get(single_element_type, {upper_limit_idx}); + auto upper_limit_cst = + rewriter.create(op->getLoc(), upper_limit_attr); + // Dynamic start indices should be clamped with upper limit of + // `shape(operand) - slice_sizes)` as per semantics of + // `stablehlo.dynamic_slice`. + // (https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice) + start_indices[i] = + rewriter.create(op->getLoc(), cast, upper_limit_cst); + } + + Value concatenated = start_indices[0]; + if (rank > 1) { + SmallVector begin_shape{rank}; + Type begin_type = operand_type.cloneWith(begin_shape, i64_type); + concatenated = rewriter.create( + op->getLoc(), begin_type, start_indices, /*axis=*/0, + /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + } + + // Clamp with lower limit. + auto lower_limit_attr = DenseIntElementsAttr::get( + single_element_type, {static_cast(0)}); + auto lower_limit_cst = + rewriter.create(op->getLoc(), lower_limit_attr); + // Dynamic start indices should be clamped with lower limit of + // 0 as per semantics of `stablehlo.dynamic_slice`. + // (https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice) + auto begin = rewriter.create(op->getLoc(), concatenated, + lower_limit_cst); + + SmallVector size_len{rank}; + TensorType size_type = operand_type.cloneWith(size_len, i64_type); + auto size_attr = DenseIntElementsAttr::get(size_type, slice_sizes); + auto size = rewriter.create(op.getLoc(), size_attr); + + rewriter.replaceOpWithNewOp(op, output, input, begin, size); + } +}; + +class RewriteQuantizedAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(stablehlo::AddOp op) const override { + return success(IsI8F32UniformQuantizedType(GetElementType(op.getLhs())) && + IsI8F32UniformQuantizedType(GetElementType(op.getRhs()))); + } + + void rewrite(stablehlo::AddOp op, PatternRewriter& rewriter) const override { + TFL::QConstOp lhs_qconst_op; + TFL::QConstOp rhs_qconst_op; + + auto GetBroadcastedConstOp = [&](Value operand) -> TFL::QConstOp { + if (auto broadcast_op = dyn_cast_or_null( + operand.getDefiningOp())) { + auto stablehlo_const_op = dyn_cast_or_null( + broadcast_op.getOperand().getDefiningOp()); + auto const_uniform_quantized_type = + stablehlo_const_op.getResult().getType().cast(); + return rewriter.create( + op.getLoc(), TypeAttr::get(const_uniform_quantized_type), + cast(stablehlo_const_op.getValue())); + } + return nullptr; + }; + + lhs_qconst_op = GetBroadcastedConstOp(op.getLhs()); + rhs_qconst_op = GetBroadcastedConstOp(op.getRhs()); + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), + lhs_qconst_op ? lhs_qconst_op : op.getOperand(0), + rhs_qconst_op ? rhs_qconst_op : op.getOperand(1), + /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + } +}; + +// Rewrites quantized `stablehlo.constant` to `tfl.pseudo_qconst`. +class RewriteQuantizedConstantOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(stablehlo::ConstantOp op) const override { + return success(IsQuantizedTensorType(op.getOutput().getType())); + } + + void rewrite(stablehlo::ConstantOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + op, /*qtype=*/TypeAttr::get(op.getOutput().getType()), + /*value=*/op.getValue()); + } +}; + +// Splits dot-like hybrid quantized StableHLO ops into `tfl.dequantize` and +// float StableHLO op. Legalization of float StableHLO op depends on existing +// passes for conversion of StableHLO -> MHLO -> TF -> TFL. +template +class RewriteHybridQuantizedDotLikeOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(OpType op) const override { + if (op->getNumOperands() != 2 || op->getNumResults() != 1) { + return failure(); + } + // Lhs and result should not be quantized and rhs should be quantized. + return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && + IsQuantizedTensorType(op->getOperand(1).getType()) && + !IsQuantizedTensorType(op->getResult(0).getType())); + } + + void rewrite(OpType op, PatternRewriter& rewriter) const override { + Value rhs = op.getOperand(1); + Type lhs_element_type = + op.getOperand(0).getType().template cast().getElementType(); + Type dequantized_rhs_type = + quant::CloneTypeWithNewElementType(rhs.getType(), lhs_element_type); + auto dq = rewriter.create( + op->getLoc(), /*output=*/dequantized_rhs_type, + /*input=*/rhs); + rewriter.replaceAllUsesExcept(rhs, dq.getOutput(), dq); + } +}; + void UniformQuantizedStableHloToTflPass::runOnOperation() { func::FuncOp func_op = getOperation(); MLIRContext& ctx = getContext(); RewritePatternSet patterns(&ctx); - patterns.add, + RewriteHybridQuantizedDotLikeOp, + RewriteUniformDequantizeOp, RewriteUniformQuantizeOp, + RewriteQuantizedAddOp, RewriteQuantizedBroadcastInDimOp, + RewriteQuantizedConcatenateOp, RewriteQuantizedConstantOp, + RewriteQuantizedConvolutionOp, RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp, - RewriteQuantizedConvolutionOp, RewriteQuantizedTransposeOp, - RewriteQuantizedReshapeOp, RewriteQuantizedSelectOp, - RewriteQuantizedConcatenateOp, RewriteQuantizedPadOp, - RewriteQuantizedSliceOp, RewriteQuantizedBroadcastInDimOp, - RewriteQuantizedReduceWindowOpWithMax, - RewriteQuantizedDynamicReshapeOp, RewriteQuantizedGatherOp>( - &ctx); + RewriteQuantizedDynamicReshapeOp, RewriteQuantizedDynamicSliceOp, + RewriteQuantizedGatherOp, RewriteQuantizedPadOp, + RewriteQuantizedReduceWindowOpWithMax, RewriteQuantizedReshapeOp, + RewriteQuantizedSelectOp, RewriteQuantizedSliceOp, + RewriteQuantizedTransposeOp>(&ctx); if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to convert stablehlo ops with uniform " diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index e1687b22816be0..2afbe2a0d2c766 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -1,6 +1,6 @@ +load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/cast_bf16.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/cast_bf16.mlir new file mode 100644 index 00000000000000..56068d605016e7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/cast_bf16.mlir @@ -0,0 +1,12 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s +// Ensure cast with bfloat16 roundtrip exactly + +func.func @main(tensor<4x5xbf16>) -> tensor<4x5xbf16> { +^bb0(%arg0: tensor<4x5xbf16>): + // CHECK-LABEL: @main + // CHECK: (tensor<4x5xbf16>) -> tensor<4x5xf32> + // CHECK-NEXT: (tensor<4x5xf32>) -> tensor<4x5xbf16> + %0 = "tfl.cast" (%arg0) : (tensor<4x5xbf16>) -> tensor<4x5xf32> loc("cast1") + %1 = "tfl.cast" (%0) : (tensor<4x5xf32>) -> tensor<4x5xbf16> loc("cast2") + func.return %1 : tensor<4x5xbf16> +} diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 685efd5be0ca2d..a0b9f90a879507 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1875,6 +1875,18 @@ func.func @matmul_batchv3_unknown_dim(%arg0: tensor, %arg1: tensor< // CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor, tensor<15x17xf32>) -> tensor } +func.func @matmul_batchv3_unknown_dim_bf16(%arg0: tensor, %arg1: tensor<5x6xf32>) -> tensor { + %0 = "tf.Cast"(%arg0) : (tensor) -> tensor + %1 = "tf.BatchMatMulV3"(%0, %arg1) {Ta = "tfdtype$DT_FLOAT", Tb = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : +(tensor, tensor<5x6xf32>) -> tensor + %2 = "tf.Cast"(%1) : (tensor) -> tensor + func.return %2 : tensor +// CHECK-LABEL: matmul_batchv3_unknown_dim_bf16 +// CHECK: [[CST:%.*]] = "tfl.cast"(%arg0) : (tensor) -> tensor +// CHECK: [[BMM:%.*]] = "tfl.batch_matmul"([[CST]], %arg1) {adj_x = false, adj_y = false} : (tensor, tensor<5x6xf32>) -> tensor +// CHECK: "tfl.cast"([[BMM]]) : (tensor) -> tensor +} + // ----- func.func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tensor<1x1x1x1x1x4xf32>, %arg2 : tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir b/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir index d7d77f2e77a97b..76f453d1d3a8aa 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir @@ -3,31 +3,31 @@ // CHECK-LABEL: func @tfl_wrapped_jax_random_normal( // CHECK-SAME: %[[RNG:.*]]: tensor<2xui32>) -> tuple> { -// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<[3, 4]> : tensor<2xi32> +// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<[3, 4]> : tensor<2xi32> // CHECK: %[[VAL_1:.*]] = "tfl.custom"(%[[VAL_0]]) {custom_code = "RandomStandardNormal", custom_option = #tfl} : (tensor<2xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]] : tuple> +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_1]] : tuple> // CHECK: return %[[VAL_2]] : tuple> // CHECK: } func.func @tfl_wrapped_jax_random_normal(%arg0: tensor<2xui32>) -> tuple> { // This is a fake jax random normal body. - %0 = mhlo.constant dense<0.0> : tensor<12xf32> - %1 = "mhlo.reshape"(%0) : (tensor<12xf32>) -> tensor<3x4xf32> - %2 = "mhlo.tuple"(%1) : (tensor<3x4xf32>) -> tuple> + %0 = stablehlo.constant dense<0.0> : tensor<12xf32> + %1 = "stablehlo.reshape"(%0) : (tensor<12xf32>) -> tensor<3x4xf32> + %2 = "stablehlo.tuple"(%1) : (tensor<3x4xf32>) -> tuple> func.return %2 : tuple> } // CHECK-LABEL: func @tfl_wrapped_jax_random_uniform( // CHECK-SAME: %[[RNG:.*]]: tensor<2xui32>) -> tuple> { -// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<[1, 2]> : tensor<2xi32> +// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<[1, 2]> : tensor<2xi32> // CHECK: %[[VAL_1:.*]] = "tfl.custom"(%[[VAL_0]]) {custom_code = "RandomUniform", custom_option = #tfl} : (tensor<2xi32>) -> tensor<1x2xf32> -// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]] : tuple> +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_1]] : tuple> // CHECK: return %[[VAL_2]] : tuple> // CHECK: } func.func @tfl_wrapped_jax_random_uniform(%arg0: tensor<2xui32>) -> tuple> { // This is a fake jax random uniform body. - %0 = mhlo.constant dense<0.0> : tensor<2xf32> - %1 = "mhlo.reshape"(%0) : (tensor<2xf32>) -> tensor<1x2xf32> - %2 = "mhlo.tuple"(%1) : (tensor<1x2xf32>) -> tuple> + %0 = stablehlo.constant dense<0.0> : tensor<2xf32> + %1 = "stablehlo.reshape"(%0) : (tensor<2xf32>) -> tensor<1x2xf32> + %2 = "stablehlo.tuple"(%1) : (tensor<1x2xf32>) -> tuple> func.return %2 : tuple> } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/cast_bf16.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/cast_bf16.mlir new file mode 100644 index 00000000000000..83255ca39a4472 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/cast_bf16.mlir @@ -0,0 +1,74 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s + +func.func @main(tensor<4x5xbf16>) -> tensor<4x5xbf16> { +^bb0(%arg0: tensor<4x5xbf16>): + +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: deprecated_builtin_code: 53, +// CHECK-NEXT: version: 7, +// CHECK-NEXT: builtin_code: CAST +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4, 5 ], +// CHECK-NEXT: type: BFLOAT16, +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "arg0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4, 5 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "cast1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4, 5 ], +// CHECK-NEXT: type: BFLOAT16, +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "cast2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0 ], +// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0 ], +// CHECK-NEXT: outputs: [ 1 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: inputs: [ 1 ], +// CHECK-NEXT: outputs: [ 2 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 4 +// CHECK-NEXT: } ], +// CHECK-NEXT: signature_defs: [ ] +// CHECK-NEXT: } + + %0 = "tfl.cast" (%arg0) : (tensor<4x5xbf16>) -> tensor<4x5xf32> loc("cast1") + %1 = "tfl.cast" (%0) : (tensor<4x5xf32>) -> tensor<4x5xbf16> loc("cast2") + func.return %1 : tensor<4x5xbf16> +} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 8548151458a26c..75c1a791eeca73 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -773,6 +773,22 @@ func.func @FuseReshapeAroundBMMLHS(%arg0: tensor<6x5x1024xf32>) -> tensor<6x5x81 // CHECK: return %0 : tensor<6x5x8192xf32> } +// CHECK-LABEL: @FuseReshapeAroundBMMLHSNegative +func.func @FuseReshapeAroundBMMLHSNegative(%arg0: tensor<1x64xf32>, %arg1: tensor<1x64x1024xf32> ) -> (tensor<1x1024xf32> ) { + %cst = arith.constant dense<[1, 1024]> : tensor<2xi32> + %cst_0 = arith.constant dense<[1, 1, 64]> : tensor<3xi32> + %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<1x64xf32>, tensor<3xi32>) -> tensor<1x1x64xf32> + %1 = "tfl.batch_matmul"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x1x64xf32>, tensor<1x64x1024xf32>) -> tensor<1x1x1024xf32> + %2 = "tfl.reshape"(%1, %cst) : (tensor<1x1x1024xf32>, tensor<2xi32>) -> tensor<1x1024xf32> + return %2 : tensor<1x1024xf32> + // CHECK: %cst = arith.constant dense<[1, 1024]> : tensor<2xi32> + // CHECK: %cst_0 = arith.constant dense<[1, 1, 64]> : tensor<3xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<1x64xf32>, tensor<3xi32>) -> tensor<1x1x64xf32> + // CHECK: %1 = "tfl.batch_matmul"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x1x64xf32>, tensor<1x64x1024xf32>) -> tensor<1x1x1024xf32> + // CHECK: %2 = "tfl.reshape"(%1, %cst) : (tensor<1x1x1024xf32>, tensor<2xi32>) -> tensor<1x1024xf32> + // CHECK: return %2 : tensor<1x1024xf32> +} + // CHECK-LABEL: @FuseReshapeAroundBMMNagativeTest func.func @FuseReshapeAroundBMMNagativeTest(%arg0: tensor<5x4x1x1024xf32>, %arg1: tensor<5x1024x8192xf32>) -> tensor<5x4x1x8192xf32> { %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 655aee59b77378..f4aa97069655e8 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -144,21 +144,18 @@ void AddPreQuantizationStableHloToTfPasses( pass_manager.addPass( mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); - // Add CHLO to StableHLO Decompositions: - // This is needed since we are relying on XlaCallModule uses MHLO - // specific features like mhlo::ErfOp which aren't supported - // in StableHLO, but we have CHLO->StableHLO decompositions to legalize. - pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); - pass_manager.addPass( - mlir::stablehlo::experimental::createChloRecomposeOpsPass()); - pass_manager.addNestedPass( - mlir::mhlo::createChloLegalizeToHloBasisOpsPass()); - pass_manager.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass()); - pass_manager.addNestedPass( - mlir::mhlo::createShapeLegalizeToHloPass()); + // Legalize MHLO to StableHLO should be moved closer to where it is needed + // There are some entry points that start with HLO->MHLO like + // jax_to_tfl_flatbuffer.cc which can likely be updated to emit StableHLO + // to be consistent with other entrypoints. pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + // Decompose CHLO into StableHLO ops + // TODO(b/331843141): There are some CHLO's like TopK which we could instead + // lower to TFL ops. + mlir::stablehlo::experimental::createChloLegalizeToStablehloPipeline( + pass_manager); + // The following two passes find specific uniform quantization patterns in // StableHLO and converts them to TFLite ops that accept or produce uniform // quantized types. They only target a specific set of models that contain @@ -174,7 +171,6 @@ void AddPreQuantizationStableHloToTfPasses( pass_manager.addNestedPass( mlir::odml::CreateUniformQuantizedStableHloToTflPass()); - pass_manager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); // Legalize jax random to tflite custom op. // The CreateLegalizeJaxRandom Pass has to stay at because we need to replace // the random function body before being inlined. @@ -182,6 +178,7 @@ void AddPreQuantizationStableHloToTfPasses( mlir::TFL::CreateLegalizeJaxRandomPass()); // Canonicalize, CSE etc. + pass_manager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); pass_manager.addNestedPass( mlir::createCanonicalizerPass()); pass_manager.addNestedPass(mlir::createCSEPass()); @@ -231,6 +228,10 @@ void AddPostQuantizationStableHloToTfPasses( pass_manager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); } + if (pass_config.enable_composite_direct_lowering) { + pass_manager.addPass(mlir::odml::CreateCompositeLoweringPass()); + } + // TFLite dialect passes. if (!pass_config.disable_hlo_to_tfl_conversion) { pass_manager.addPass(mlir::odml::CreateLegalizeHloToTfLitePass()); @@ -252,6 +253,16 @@ void AddPostQuantizationStableHloToTfPasses( // Legalize all remaining mhlo ops to stableHLO pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + + // Translate "stablehlo.custom_call @stablehlo.composite" to + // "stablehlo.composite" + // TODO: b/330741524 - clean this up when "stablehlo.composite" is emitted + // directly. Additionally remove the composite to custom once ODML long term + // solution lands. + pass_manager.addPass( + mlir::odml::createLegalizeStablehloCustomCallToCompositePass()); + pass_manager.addNestedPass( + mlir::odml::createLegalizeCompositeToCustomOpPass()); } // This is the early part of the conversion in isolation. This enables a caller diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 7a3f06bb376784..ac4de7f82b23d0 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -497,6 +497,13 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer( options.metadata.insert( MetadataForReducedPrecisionSupport(quant_specs.support_mask)); } + pass_manager.clear(); + pass_manager.addPass(mlir::odml::createLegalizeStablehloToVhloPass()); + if (failed(pass_manager.run(module))) { + return status_handler.Combine( + absl::InvalidArgumentError("VHLO lowering failed")); + } + if (!tflite::MlirToFlatBufferTranslateFunction( module, options, &translated_result, serialize_stablehlo_ops)) { return status_handler.Combine( diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 208c20492c10f6..2f015e61d58fe6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -27,10 +27,10 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/utils/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc index 72120f1502f021..e8bae6eb64280f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc @@ -47,10 +47,10 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { namespace TFL { @@ -99,7 +99,7 @@ void LegalizeJaxRandomPass::runOnOperation() { } auto result_shape_attr = builder.getI32TensorAttr(result_shape_i32); Value result_shape_tensor = - builder.create(result_shape_attr); + builder.create(result_shape_attr); auto custom_code = IsJaxRandomUniform(func) ? "RandomUniform" : "RandomStandardNormal"; @@ -112,7 +112,7 @@ void LegalizeJaxRandomPass::runOnOperation() { ValueRange(result_shape_tensor_vec), custom_code, attr) .getResult(0); - Value tulple_result = builder.create(random_result); + Value tulple_result = builder.create(random_result); builder.create(tulple_result); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 7f55ca054383fa..401f34e6e7943c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -49,9 +49,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index c96266da31ddc9..0b068972c8fd30 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -1529,7 +1529,8 @@ def FuseReshapesAroundBatchMatMulLHS1: Pat< $rhs, $adj_x, $adj_y, $bool_attr), (Arith_ConstantOp $s1)), (TFL_BatchMatMulOp $input, $rhs, $adj_x, $adj_y, $bool_attr), - [(HasRank<3> $rhs), + [(HasRankAtLeast<3> $input), + (HasRank<3> $rhs), (HasRank<3> $initial_shape_change), (IsBroadcastDimEqualToOne $rhs), (IsBroadcastDimEqualToOne $input), diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 988ad189a6ec00..eefb109d2b966e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -108,7 +108,7 @@ def LegalizeHashTablesPass : Pass<"tfl-legalize-hashtables-tf", "mlir::ModuleOp" def LegalizeJaxRandomPass : Pass<"tfl-legalize-random", "mlir::func::FuncOp"> { let summary = "Replace jax.random.uniform/normal with tfl.custom."; let constructor = "CreateLegalizeJaxRandomPass()"; - let dependentDialects = ["TFL::TensorFlowLiteDialect"]; + let dependentDialects = ["TFL::TensorFlowLiteDialect", "stablehlo::StablehloDialect"]; } def LegalizeTFPass : Pass<"tfl-legalize-tf", "mlir::func::FuncOp"> { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 6694db441b6566..ce11ca73970136 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -43,7 +43,6 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h index da5a941179deb8..e102c6bedd4328 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -38,8 +38,8 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 41a3144da6fe87..9f0a7fbafff450 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -59,9 +59,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc index 33580d1ea95dbc..0d9db051ef27ff 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc index 5ce7638f4e4da1..96d75cca30a48d 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc @@ -345,22 +345,41 @@ StatusOr ConvertFloatBuffer( switch (elem_type.getIntOrFloatBitWidth()) { case 16: { assert(bytes_len % 2 == 0); - assert(elem_type.isF16()); + // Supports both BF16 and F16. + assert(elem_type.isF16() || elem_type.isBF16()); int elem_count = bytes_len / 2; - std::vector values; - values.reserve(elem_count); - const char* data = reinterpret_cast(buffer.data()); + if (elem_type.isF16()) { + std::vector values; + values.reserve(elem_count); - for (int i = 0; i < elem_count; i++) { - uint16_t bit_repr = - llvm::support::endian::readNext(data); - values.push_back(Eigen::numext::bit_cast(bit_repr)); - } + const char* data = reinterpret_cast(buffer.data()); - return mlir::ElementsAttr( - DenseElementsAttr::get(shaped_type, ArrayRef(values))); + for (int i = 0; i < elem_count; i++) { + uint16_t bit_repr = llvm::support::endian::readNext< + uint16_t, llvm::endianness::native, llvm::support::unaligned>( + data); + values.push_back(Eigen::numext::bit_cast(bit_repr)); + } + + return mlir::ElementsAttr( + DenseElementsAttr::get(shaped_type, ArrayRef(values))); + } else { + std::vector values; + values.reserve(elem_count); + + const char* data = reinterpret_cast(buffer.data()); + + for (int i = 0; i < elem_count; i++) { + uint16_t bit_repr = llvm::support::endian::readNext< + uint16_t, llvm::endianness::native, llvm::support::unaligned>( + data); + values.push_back(Eigen::numext::bit_cast(bit_repr)); + } + + return mlir::ElementsAttr(DenseElementsAttr::get( + shaped_type, ArrayRef(values))); + } } case 32: { assert(bytes_len % 4 == 0); diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index 9b215e77b89529..e09030ceb7515f 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -34,6 +34,8 @@ namespace errors = tensorflow::errors; tflite::TensorType ConvertTypeToTensorType(mlir::Type type) { if (type.isF16()) { return tflite::TensorType_FLOAT16; + } else if (type.isBF16()) { + return tflite::TensorType_BFLOAT16; } else if (type.isF32()) { return tflite::TensorType_FLOAT32; } else if (type.isF64()) { @@ -81,6 +83,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { switch (type) { case tflite::TensorType_FLOAT16: return builder.getF16Type(); + case tflite::TensorType_BFLOAT16: + return builder.getBF16Type(); case tflite::TensorType_FLOAT32: return builder.getF32Type(); case tflite::TensorType_FLOAT64: @@ -128,6 +132,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { return tensorflow::DT_COMPLEX128; case tflite::TensorType_FLOAT16: return tensorflow::DT_HALF; + case tflite::TensorType_BFLOAT16: + return tensorflow::DT_BFLOAT16; case tflite::TensorType_FLOAT32: return tensorflow::DT_FLOAT; case tflite::TensorType_FLOAT64: @@ -170,6 +176,8 @@ absl::StatusOr TfTypeToTflType(tensorflow::DataType type) { return tflite::TensorType_COMPLEX128; case tensorflow::DT_HALF: return tflite::TensorType_FLOAT16; + case tensorflow::DT_BFLOAT16: + return tflite::TensorType_BFLOAT16; case tensorflow::DT_FLOAT: return tflite::TensorType_FLOAT32; case tensorflow::DT_DOUBLE: diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD index 8e4c39b8d5f1b7..8091fe21ef56ff 100644 --- a/tensorflow/compiler/mlir/quantization/common/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -41,6 +41,7 @@ cc_library( "//tensorflow/core:framework_lite", "//tensorflow/core/ir/types:Dialect", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -48,6 +49,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", ], ) @@ -60,6 +62,7 @@ tf_cc_test( ":test_base", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -68,6 +71,7 @@ tf_cc_test( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status_matchers", "@stablehlo//:stablehlo_ops", ], @@ -109,6 +113,7 @@ cc_library( hdrs = ["test_base.h"], compatible_with = get_compatible_with_portable(), deps = [ + ":func", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:context", @@ -122,6 +127,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index 42ecca536f54a5..852902e229a9fc 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ +#include #include #include #include @@ -40,10 +41,19 @@ namespace mlir::quant { constexpr char kAttrMapAttribute[] = "attr_map"; -// TODO: b/238829558 - Populate quantization config based on the -// QuantizationOptions proto. -// TODO: b/263449239 - Put the OpSet aliases separately within each file -using OpSet = tensorflow::quantization::OpSet; +// Permutation from the NHWC tensor format to NCHW. This is an inverse +// permutation of `kNchwToNhwcPermutation`. +inline constexpr std::array kNhwcToNchwPermutation = {0, 3, 1, 2}; + +// Permutation from the NCHW tensor format to NHWC. This is an inverse +// permutation of `kNchwToNhwcPermutation`. +inline constexpr std::array kNchwToNhwcPermutation = {0, 2, 3, 1}; + +// Permutation from the OIHW (== (output features, input features, height, +// width)) tensor format to HWIO. This is commonly used to transpose convolution +// weights represented as OIHW format to HWIO, which is more desirable for +// certain downstream optimization passes (e.g. XLA). +inline constexpr std::array kOihwToHwioPermutation = {2, 3, 1, 0}; // Returns true if the value has static shape. bool HasStaticShape(Value value); diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc index 6ec7285a8e7406..f6e633aa4c7861 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc @@ -72,7 +72,7 @@ constexpr absl::string_view kModuleMultipleUses = R"mlir( module { func.func @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> - %1 = stablehlo.subtract %0, %arg2 : tensor<1x3xf32> + %1 = stablehlo.subtract %arg2, %0 : tensor<1x3xf32> %2 = stablehlo.add %0, %arg2 : tensor<1x3xf32> return %2 : tensor<1x3xf32> } @@ -411,9 +411,8 @@ TEST_F(AttrsAndConstraintsTest, HasQuantizableTraitFalse) { } TEST_F(AttrsAndConstraintsTest, IsHybridQuantizedOpTrue) { - OwningOpRef module_op_ref = - ParseModuleOpString(kModuleHybridQuantized); - func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + OwningOpRef module_op = ParseModuleOpString(kModuleHybridQuantized); + func::FuncOp main_fn = FindMainFuncOp(*module_op); ASSERT_THAT(main_fn, NotNull()); Operation* dot_general = FindOperationOfType(main_fn); @@ -421,9 +420,8 @@ TEST_F(AttrsAndConstraintsTest, IsHybridQuantizedOpTrue) { } TEST_F(AttrsAndConstraintsTest, IsHybridQuantizedOpFalse) { - OwningOpRef module_op_ref = - ParseModuleOpString(kModuleXlaCallModule); - func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + OwningOpRef module_op = ParseModuleOpString(kModuleXlaCallModule); + func::FuncOp main_fn = FindMainFuncOp(*module_op); ASSERT_THAT(main_fn, NotNull()); Operation* call_op = FindOperationOfType(main_fn); @@ -453,17 +451,25 @@ constexpr absl::string_view kModuleDotGeneralBatchMatmul = R"mlir( )mlir"; TEST_F(AttrsAndConstraintsTest, DotGeneralFullyConnectedReturnsQuantDim) { - OwningOpRef module_op_ref = + OwningOpRef module_op = ParseModuleOpString(kModuleDotGeneralFullyConnected); - func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + auto dot_general_op = *main_fn.getOps().begin(); EXPECT_THAT(GetDotGeneralQuantizationDim(dot_general_op), Optional(1)); } TEST_F(AttrsAndConstraintsTest, DotGeneralBatchMatmulReturnsNullQuantDim) { - OwningOpRef module_op_ref = + OwningOpRef module_op = ParseModuleOpString(kModuleDotGeneralBatchMatmul); - func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + auto dot_general_op = *main_fn.getOps().begin(); EXPECT_THAT(GetDotGeneralQuantizationDim(dot_general_op), Eq(std::nullopt)); } diff --git a/tensorflow/compiler/mlir/quantization/common/ir/BUILD b/tensorflow/compiler/mlir/quantization/common/ir/BUILD index 1f62dff9711d80..615f54f70d2373 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/ir/BUILD @@ -57,15 +57,24 @@ gentbl_cc_library( cc_library( name = "QuantOps", srcs = [ + "FakeQuantSupport.cc", "QuantOps.cc", + "UniformSupport.cc", + ], + hdrs = [ + "FakeQuantSupport.h", + "QuantOps.h", + "UniformSupport.h", ], - hdrs = ["QuantOps.h"], compatible_with = get_compatible_with_portable(), deps = [ ":QuantOpsIncGen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.cc b/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.cc similarity index 88% rename from tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.cc rename to tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.cc index 9b662ebdca8461..292e0eeb3cce71 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.cc @@ -13,12 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" + +#include +#include +#include +#include +#include #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project -using namespace mlir; -using namespace mlir::quantfork; +namespace mlir::quantfork { static bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned, MLIRContext *ctx, @@ -121,9 +131,11 @@ static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, assert(nudgedZeroPoint <= qmax); } -quant::UniformQuantizedType mlir::quantfork::fakeQuantAttrsToType( - Location loc, unsigned numBits, double rmin, double rmax, bool narrowRange, - Type expressedType, bool isSigned) { +quant::UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, + double rmin, double rmax, + bool narrowRange, + Type expressedType, + bool isSigned) { MLIRContext *ctx = expressedType.getContext(); unsigned flags = isSigned ? quant::QuantizationFlags::Signed : 0; Type storageType; @@ -152,7 +164,7 @@ quant::UniformQuantizedType mlir::quantfork::fakeQuantAttrsToType( nudgedZeroPoint, qmin, qmax); } -quant::UniformQuantizedPerAxisType mlir::quantfork::fakeQuantAttrsToType( +quant::UniformQuantizedPerAxisType fakeQuantAttrsToType( Location loc, unsigned numBits, int32_t quantizedDimension, ArrayRef rmins, ArrayRef rmaxs, bool narrowRange, Type expressedType, bool isSigned) { @@ -198,3 +210,5 @@ quant::UniformQuantizedPerAxisType mlir::quantfork::fakeQuantAttrsToType( loc, flags, storageType, expressedType, scales, zeroPoints, quantizedDimension, qmin, qmax); } + +} // namespace mlir::quantfork diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h b/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h similarity index 93% rename from tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h rename to tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h index 6072172eaebe38..335f80782a5e20 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h +++ b/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h @@ -41,8 +41,8 @@ limitations under the License. // //===----------------------------------------------------------------------===// -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_FAKEQUANTSUPPORT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_FAKEQUANTSUPPORT_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_FAKEQUANTSUPPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_FAKEQUANTSUPPORT_H_ #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project @@ -71,4 +71,4 @@ quant::UniformQuantizedPerAxisType fakeQuantAttrsToType( } // namespace quantfork } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_FAKEQUANTSUPPORT_H_ +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_FAKEQUANTSUPPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.cc b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc similarity index 97% rename from tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.cc rename to tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc index e5c3dd35a27981..5a200241af00dd 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" #include diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h similarity index 97% rename from tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h rename to tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h index 064afb0b36aa13..b6f65e455d0c09 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_UNIFORMSUPPORT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_UNIFORMSUPPORT_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ #include @@ -237,4 +237,4 @@ class UniformQuantizedPerAxisValueConverter { } // namespace quantfork } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_UNIFORMSUPPORT_H_ +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc index 9c700ed50bc4d0..050bf45d7b5a46 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -69,30 +70,34 @@ constexpr int64_t kDefaultVersion = 9; constexpr StringRef kPlatformCpu = "CPU"; // Name of `tf.XlaCallModule`'s dictionary attribute for keeping the // deserialized stablehlo module's attributes. -constexpr llvm::StringRef kStablehloModuleAttrsAttrName = - "_stablehlo_module_attrs"; +constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; // Attribute required for running shape refinement pass enabled in XlaCallModule // version 8 and above. -constexpr llvm::StringRef kUsesShapePolymorphismAttr = - "jax.uses_shape_polymorphism"; +constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; -// Checks if the op is inside a lifted function. -bool IsInLiftedFunc(Operation& op) { - return op.getParentOfType()->hasAttr(kFusedFunctionAttr); +bool IsInLiftedFunc(Operation* op) { + if (op == nullptr) return false; + return op->getParentOfType()->hasAttr(kFusedFunctionAttr); +} + +bool IsInStableHloOpRegion(Operation* op) { + if (op == nullptr) return false; + auto parent_op = op->getParentOp(); + return parent_op != nullptr && stablehlo::IsStablehloOp(parent_op); } // Inserts the function to the symbol table of the module thread-safely. StringAttr InsertToSymbolTable(Operation& module, Operation& function, - const std::string& func_name) { + const StringRef func_name) { static tensorflow::mutex* mtx = new tensorflow::mutex(); tensorflow::mutex_lock lock(*mtx); SymbolTable symbol_table(&module); - std::string unique_name = func_name; + std::string unique_name = func_name.str(); int32_t uniquing_counter = 0; while (symbol_table.lookup(unique_name) != nullptr) { ++uniquing_counter; - unique_name = func_name + "_" + std::to_string(uniquing_counter); + unique_name = absl::StrCat(func_name.str(), "_", uniquing_counter); } function.setAttr("sym_name", StringAttr::get(module.getContext(), unique_name)); @@ -101,9 +106,11 @@ StringAttr InsertToSymbolTable(Operation& module, Operation& function, // Creates the TF::PartitionedCallOp with the given arguments and output types. // This function call op is for invoking the TF subgraphs. -ValueRange createTFPartitionedCallOp(OpBuilder builder, Location location, - StringRef func_name, - TypeRange output_types, ValueRange args) { +ValueRange CreateTFPartitionedCallOp(OpBuilder& builder, + const Location location, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { TF::PartitionedCallOp call_op = builder.create( location, output_types, args, FlatSymbolRefAttr::get(builder.getStringAttr(func_name)), @@ -112,7 +119,7 @@ ValueRange createTFPartitionedCallOp(OpBuilder builder, Location location, // Set the attribute to annotate this function call op as a quantizable spot. call_op->setAttr( kQuantTraitAttrName, - builder.getStringAttr(llvm::StringRef( + builder.getStringAttr(StringRef( std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); return call_op.getOutput(); @@ -120,10 +127,11 @@ ValueRange createTFPartitionedCallOp(OpBuilder builder, Location location, // Creates the TF::XlaCallModuleOp with the given arguments and output types. // This function call op is for invoking the StableHLO subgraphs. -ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, - StringRef func_name, TypeRange output_types, - ValueRange args) { - auto ctx = builder.getContext(); +ValueRange CreateTFXlaCallModuleOp(OpBuilder& builder, const Location location, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { + MLIRContext* ctx = builder.getContext(); // Collect the shapes of the output to fill up the Sout attribute. SmallVector shape_attrs; for (const Type result_type : output_types) { @@ -133,7 +141,7 @@ ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, auto empty_array_attr = ArrayAttr::get(ctx, {}); auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)}); - TF::XlaCallModuleOp call_op = builder.create( + auto call_op = builder.create( location, /*output=*/output_types, /*args=*/args, @@ -159,7 +167,7 @@ ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, // Set the attribute to annotate this function call op as a quantizable spot. call_op->setAttr( kQuantTraitAttrName, - builder.getStringAttr(llvm::StringRef( + builder.getStringAttr(StringRef( std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); // Set jax.uses_shape_polymorphism=true to enable shape refinement at runtime. @@ -172,27 +180,25 @@ ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, } // Creates the function call op based on the given call_op_type argument. -ValueRange createFunctionCallOp(OpBuilder builder, Location location, - FunctionCallOpType call_op_type, - StringRef func_name, TypeRange output_types, - ValueRange args) { +ValueRange CreateFunctionCallOp(OpBuilder& builder, const Location location, + const FunctionCallOpType call_op_type, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { switch (call_op_type) { case FunctionCallOpType::TFXlaCallModuleOp: - return createTFXlaCallModuleOp(builder, location, func_name, output_types, + return CreateTFXlaCallModuleOp(builder, location, func_name, output_types, args); case FunctionCallOpType::TFPartitionedCallOp: - return createTFPartitionedCallOp(builder, location, func_name, + return CreateTFPartitionedCallOp(builder, location, func_name, output_types, args); - default: - llvm_unreachable("unhandled call op type"); } } // Finds ops in the paths from arguments to results. The ops is listed in an // order that the former ops shouldn't have any dependencies on the later ones. -llvm::SmallVector FindOpsFromArgumentsToResults( - const llvm::SmallVector& arguments, - const llvm::SmallVector& results) { +SmallVector FindOpsFromArgumentsToResults( + const ArrayRef arguments, const ArrayRef results) { std::queue value_queue; for (Value result : results) { value_queue.push(result); @@ -213,7 +219,7 @@ llvm::SmallVector FindOpsFromArgumentsToResults( Operation* defining_node = current_value.getDefiningOp(); if (defining_node == nullptr) continue; op_stack.push(defining_node); - for (const auto& arg : defining_node->getOperands()) { + for (Value arg : defining_node->getOperands()) { if (!argument_set.contains(arg.getImpl())) { value_queue.push(arg); } @@ -221,7 +227,7 @@ llvm::SmallVector FindOpsFromArgumentsToResults( } // Remove duplicate ops from the op stack. - llvm::SmallVector sorted_ops; + SmallVector sorted_ops; absl::flat_hash_set unique_ops; while (!op_stack.empty()) { Operation* current_op = op_stack.top(); @@ -243,9 +249,9 @@ llvm::SmallVector FindOpsFromArgumentsToResults( // "0:transpose_a,1:transpose_b", where 0 and 1 are the respective attribute // identifiers. // This function returns success if all attributes could be found. -LogicalResult SetAttributeMap( - MLIRContext& context, const llvm::SmallVector& attributes, - const llvm::SmallVector& ops) { +LogicalResult SetAttributeMap(MLIRContext& context, + const ArrayRef attributes, + const ArrayRef ops) { // A map to find which operation an attribute belongs to. // The key for this map uses the entire NamedAttribute object, i.e. the // {attribute_name, attribute_value} pair. @@ -270,8 +276,8 @@ LogicalResult SetAttributeMap( attr_to_op_map.begin(), attr_to_op_map.end(), [&](auto attr_op) { return std::get<0>(attr_op).getName() == attribute.getName(); }) == attr_to_op_map.end()) { - mlir::emitError(UnknownLoc::get(&context), - "Could not find attribute: " + attribute.getName().str()); + emitError(UnknownLoc::get(&context), + "Could not find attribute: " + attribute.getName().str()); return failure(); } @@ -293,7 +299,7 @@ LogicalResult SetAttributeMap( // Append ":". Ex) "0:transpose_a". const std::string identifier = std::to_string(idx); - const mlir::StringAttr attribute_name = attribute.getName(); + const StringAttr attribute_name = attribute.getName(); absl::StrAppend(&new_attr_map_str, identifier, ":", attribute_name.str()); owner_op->setAttr(kAttrMapAttribute, StringAttr::get(&context, new_attr_map_str)); @@ -303,14 +309,14 @@ LogicalResult SetAttributeMap( } // Creates a function to wrap the section between arguments and results. -llvm::SmallVector LiftAsFunctionCall( - OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector& arguments, - const llvm::SmallVector& results, - const llvm::SmallVector& attributes) { +SmallVector LiftAsFunctionCall( + OpBuilder& builder, const Location location, + const FunctionCallOpType call_op_type, const StringRef func_name, + const ArrayRef arguments, const ArrayRef results, + const ArrayRef attributes) { MLIRContext* context = builder.getContext(); if (results.empty()) { - mlir::emitError(UnknownLoc::get(context), "No result values specified"); + emitError(UnknownLoc::get(context), "No result values specified"); return {}; } Operation* result_op = results[0].getDefiningOp(); @@ -324,10 +330,11 @@ llvm::SmallVector LiftAsFunctionCall( TypeRange result_types{ValueRange{results}}; auto func_type = FunctionType::get(context, arg_types, result_types); - llvm::SmallVector arg_locs; - for (const auto& arg : arguments) { + SmallVector arg_locs; + for (Value arg : arguments) { arg_locs.push_back(arg.getLoc()); } + auto wrap_func = builder.create(location, func_name, func_type); wrap_func.setVisibility(SymbolTable::Visibility::Private); // The callee function for TF::XlaCallModuleOp must have this attribute. @@ -361,34 +368,36 @@ llvm::SmallVector LiftAsFunctionCall( builder.clone(*op, mapping); } - llvm::SmallVector return_values; + SmallVector return_values; for (Value result : results) { return_values.push_back(mapping.lookupOrNull(result)); } - builder.create(location, return_values); + builder.create(location, return_values); // Create a function call to the newly created function. StringAttr new_func_name = - InsertToSymbolTable(*module, *wrap_func, func_name.str()); + InsertToSymbolTable(*module, *wrap_func, func_name); builder.setInsertionPointAfter(result_op); ValueRange new_results = - createFunctionCallOp(builder, call_op_loc, call_op_type, + CreateFunctionCallOp(builder, call_op_loc, call_op_type, new_func_name.getValue(), result_types, arguments); - return llvm::SmallVector(new_results.begin(), new_results.end()); + return SmallVector(new_results.begin(), new_results.end()); } -llvm::SmallVector LiftAsFunctionCall( - OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector& arguments, - const llvm::SmallVector& results) { - llvm::SmallVector attributes; +SmallVector LiftAsFunctionCall(OpBuilder& builder, + const Location location, + const FunctionCallOpType call_op_type, + const StringRef func_name, + const ArrayRef arguments, + const ArrayRef results) { + SmallVector attributes; return LiftAsFunctionCall(builder, location, call_op_type, func_name, arguments, results, attributes); } -llvm::SmallVector AppendToVector( - const llvm::SmallVector& arguments, Value append) { - llvm::SmallVector ret(arguments); +SmallVector AppendToVector(const ArrayRef arguments, + Value append) { + SmallVector ret(arguments); ret.push_back(append); return ret; } @@ -402,7 +411,7 @@ llvm::SmallVector AppendToVector( // could process the following equation by setting the attributes properly: // abc,cd->abd. // 4. The output should be in the form: [batch dims][lhs dims][rhs dims] -bool IsEinsumSupportedByXlaDotV2(mlir::StringAttr equation_attr) { +bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr) { StringRef equation = equation_attr.getValue(); if (!absl::StrContains(equation, "->") || !absl::StrContains(equation, ",") || @@ -489,4 +498,15 @@ absl::StatusOr GetQuantizationMethod( return quantization_method; } +Method GetQuantizationMethodOrDefault(TF::XlaCallModuleOp xla_call_module_op) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.status().code() == absl::StatusCode::kInternal) { + // This indicates that the `Method` protobuf string is corrupt, but this + // function ignores it and returns the default instance. + xla_call_module_op->emitError(absl::StrCat( + "Failed to get quantization method: ", method.status().ToString())); + } + return method.ok() ? *method : Method::default_instance(); +} + } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h index f2edd732f50cc5..bd7421d376102b 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h @@ -47,11 +47,16 @@ inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; // function lifting will happen. enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 }; -// Checks if the op is inside a lifted function. -bool IsInLiftedFunc(Operation &op); +// Checks if an op is inside a lifted function. +// If the given op pointer is a nullptr, returns false. +bool IsInLiftedFunc(Operation* op); -// Checks if the given einsum op is supported for XlaDotV2 quantization. -bool IsEinsumSupportedByXlaDotV2(mlir::StringAttr equation_attr); +// Checks if the op is inside a StableHLO op with region. +// If the given op pointer is a nullptr, returns false. +bool IsInStableHloOpRegion(Operation* op); + +// Checks if a given einsum op is supported for XlaDotV2 quantization. +bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr); // Gets the quantization method from the given `XlaCallModuleOp`. It is // retrieved from the `kQuantizationMethodAttr` string attribute. Returns @@ -60,27 +65,35 @@ bool IsEinsumSupportedByXlaDotV2(mlir::StringAttr equation_attr); absl::StatusOr<::stablehlo::quantization::Method> GetQuantizationMethod( TF::XlaCallModuleOp xla_call_module_op); +// Gets the quantization method from the given `XlaCallModuleOp`. It is +// retrieved from the `kQuantizationMethodAttr` string attribute. Returns a +// default instance of `Method` iff the attribute doesn't exist or the attribute +// contains an invalid textproto for `Method`. +::stablehlo::quantization::Method GetQuantizationMethodOrDefault( + TF::XlaCallModuleOp xla_call_module_op); + // Creates a function to wrap the section between arguments and results. // The generated function call op type will be decided by the given call_op_type // argument. Currently, it supports TF::XlaCallModuleOp and // TF::PartitionedCallOp function call op generations. -llvm::SmallVector LiftAsFunctionCall( - OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector &arguments, - const llvm::SmallVector &results, - const llvm::SmallVector &attributes); +SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results, + ArrayRef attributes); // Same as above but with empty attributes. -llvm::SmallVector LiftAsFunctionCall( - OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector &arguments, - const llvm::SmallVector &results); +SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results); // Add the second argument to the first argument, which is expected to be an // argument list. // Used to attach bias to einsum argument list. -llvm::SmallVector AppendToVector( - const llvm::SmallVector &arguments, Value append); +SmallVector AppendToVector(ArrayRef arguments, Value append); } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td index a4437b50ac0cf0..1ca03a803bef4d 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td @@ -59,7 +59,11 @@ class NamedAttr : // Checks if the value is not defined inside a lifted function by checking the // `tf_quant.composite_function` attribute. def IsNotInLiftedFunc : - Constraint>; + Constraint>; + +// Checks if the value is not inside a StableHLO op with region. +def IsNotInStableHloOpRegion : + Constraint>; // Checks if the given einsum op is supported for XlaDotV2 quantization. def IsEinsumSupportedByXlaDotV2 : diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc index 3d1285928f5f18..c37a997217d2b7 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -30,11 +31,13 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/status_matchers.h" namespace mlir::quant { @@ -43,10 +46,11 @@ namespace { using ::stablehlo::quantization::Method; using ::testing::HasSubstr; using ::testing::NotNull; +using ::tsl::protobuf::util::MessageDifferencer; using ::tsl::testing::IsOk; using ::tsl::testing::StatusIs; -using LiftAsFunctionCallTest = ::mlir::quant::QuantizationTestBase; +using LiftAsFunctionCallTest = QuantizationTestBase; constexpr absl::string_view kModuleLifted = R"mlir( module { @@ -65,10 +69,9 @@ TEST_F(LiftAsFunctionCallTest, LiftedFunctionSucceeds) { module_op->lookupSymbol("composite_dot_general_fn_1"); ASSERT_THAT(composite_dot_general_fn, NotNull()); - Operation* dot_general_op = - FindOperationOfType( - composite_dot_general_fn); - EXPECT_TRUE(IsInLiftedFunc(*dot_general_op)); + auto dot_general_op = FindOperationOfType( + composite_dot_general_fn); + EXPECT_TRUE(IsInLiftedFunc(dot_general_op)); } constexpr absl::string_view kModuleStableHlo = R"mlir( @@ -87,7 +90,7 @@ TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { func::FuncOp main_fn = FindMainFuncOp(*module_op); ASSERT_THAT(main_fn, NotNull()); - Operation* dot_general_op = + auto dot_general_op = FindOperationOfType(main_fn); const SmallVector& attributes = { @@ -97,19 +100,20 @@ TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { 1, mlir::stablehlo::PrecisionAttr::get( ctx_.get(), mlir::stablehlo::Precision::DEFAULT)))), }; + const SmallVector operands(dot_general_op->getOperands()); + const SmallVector results(dot_general_op->getResults()); Operation* lifted_op = LiftAsFunctionCall(builder_, dot_general_op->getLoc(), FunctionCallOpType::TFXlaCallModuleOp, - "composite_dot_general_fn", - dot_general_op->getOperands(), - dot_general_op->getResults(), attributes)[0] + "composite_dot_general_fn", operands, results, + attributes)[0] .getDefiningOp(); const auto entry_function_symbol_ref = lifted_op->getAttrOfType("_entry_function"); SymbolTable symbol_table(*module_op); auto entry_func = dyn_cast_or_null( symbol_table.lookup(entry_function_symbol_ref.getValue())); - Operation* lifted_dot_general_op = + auto lifted_dot_general_op = FindOperationOfType(entry_func); EXPECT_TRUE(isa(lifted_op)); @@ -129,13 +133,14 @@ TEST_F(LiftAsFunctionCallTest, FunctionNoAttrLiftedAsXlaCallModuleOp) { func::FuncOp main_fn = FindMainFuncOp(*module_op); ASSERT_THAT(main_fn, NotNull()); - Operation* dot_general_op = + auto dot_general_op = FindOperationOfType(main_fn); + const SmallVector operands(dot_general_op->getOperands()); + const SmallVector results(dot_general_op->getResults()); Operation* lifted_op = - LiftAsFunctionCall( - builder_, dot_general_op->getLoc(), - FunctionCallOpType::TFXlaCallModuleOp, "composite_dot_general_fn", - dot_general_op->getOperands(), dot_general_op->getResults())[0] + LiftAsFunctionCall(builder_, dot_general_op->getLoc(), + FunctionCallOpType::TFXlaCallModuleOp, + "composite_dot_general_fn", operands, results)[0] .getDefiningOp(); EXPECT_TRUE(isa(lifted_op)); EXPECT_EQ(lifted_op->getAttr("_original_entry_function").cast(), @@ -242,5 +247,109 @@ TEST_F(LiftAsFunctionCallTest, HasSubstr("Failed to parse Method from textproto"))); } +constexpr absl::string_view kFunctionWithRegion = + R"mlir( + func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %if = "stablehlo.if"(%arg0) ({ + %0 = stablehlo.add %arg1, %arg1 : tensor + stablehlo.return %0 : tensor + }, { + %1 = stablehlo.add %arg2, %arg2 : tensor + stablehlo.return %1 : tensor + }) : (tensor) -> (tensor) + %subtract = stablehlo.subtract %if, %if : tensor + return %subtract : tensor + } +)mlir"; + +TEST_F(LiftAsFunctionCallTest, IsInRegionSucceedsWhenOpInsideRegion) { + const OwningOpRef module_op = + ParseModuleOpString(kFunctionWithRegion); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto if_op = FindOperationOfType(main_fn); + Block& block = if_op->getRegion(0).front(); + Operation& add_op = *absl::c_find_if(block, [](Operation& entry) { + return dyn_cast_or_null<::mlir::stablehlo::AddOp>(&entry); + }); + EXPECT_TRUE(IsInStableHloOpRegion(&add_op)); +} + +TEST_F(LiftAsFunctionCallTest, IsInRegionFailsWhenOpNotInsideRegion) { + const OwningOpRef module_op = + ParseModuleOpString(kFunctionWithRegion); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto subtract_op = FindOperationOfType(main_fn); + EXPECT_FALSE(IsInStableHloOpRegion(subtract_op)); +} + +TEST_F(LiftAsFunctionCallTest, + GetQuantizationMethodOrDefaultReturnsCorrectMethod) { + // Function containing a simple `TF::XlaCallModuleOp` with a valid string + // attribute `_quantization_method` set to `"no_quantization { }"`. + constexpr absl::string_view kXlaCallModuleOpWithQuantizationMethodAttr = + R"mlir( + func.func @main(%arg0: tensor<1x1x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x1x4xf32> { + %0 = "tf.XlaCallModule"(%arg0, %arg1) <{Sout = [#tf_type.shape<1x1x4>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> + { + _entry_function = @composite_dot_general_fn_1, + _quantization_method = "no_quantization { }", + _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true} + } : (tensor<1x1x3xf32>, tensor<3x4xf32>) -> tensor<1x1x4xf32> + return %0 : tensor<1x1x4xf32> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kXlaCallModuleOpWithQuantizationMethodAttr); + ASSERT_TRUE(module_op); + + FailureOr xla_call_module_op = + FindFirstOpFromMainFunc(*module_op); + ASSERT_TRUE(succeeded(xla_call_module_op)); + + // Test that `GetQuantizationMethodOrDefault` returns a valid `Method` + // corresponding to `"no_quantization {}"`. + const Method method = GetQuantizationMethodOrDefault(*xla_call_module_op); + EXPECT_TRUE(method.has_no_quantization()); +} + +TEST_F( + LiftAsFunctionCallTest, + GetQuantizationMethodOrDefaultReturnsDefaultWhenNoQuantizationMethodAttr) { + // Function containing a simple `TF::XlaCallModuleOp` that is missing the + // "_quantization_method" attribute. + constexpr absl::string_view kXlaCallModuleOpWithoutQuantizationMethodAttr = + R"mlir( + func.func @main(%arg0: tensor<1x1x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x1x4xf32> { + %0 = "tf.XlaCallModule"(%arg0, %arg1) <{Sout = [#tf_type.shape<1x1x4>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> + { + _entry_function = @composite_dot_general_fn_1, + _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true} + } : (tensor<1x1x3xf32>, tensor<3x4xf32>) -> tensor<1x1x4xf32> + return %0 : tensor<1x1x4xf32> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kXlaCallModuleOpWithoutQuantizationMethodAttr); + ASSERT_TRUE(module_op); + + FailureOr xla_call_module_op = + FindFirstOpFromMainFunc(*module_op); + ASSERT_TRUE(succeeded(xla_call_module_op)); + + // Test that `GetQuantizationMethodOrDefault` returns the default instance. + const Method method = GetQuantizationMethodOrDefault(*xla_call_module_op); + EXPECT_TRUE(MessageDifferencer::Equals(method, Method::default_instance())); +} + } // namespace } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/python/testing.py b/tensorflow/compiler/mlir/quantization/common/python/testing.py index 78eb2409c70f89..211e08df7d9e4b 100644 --- a/tensorflow/compiler/mlir/quantization/common/python/testing.py +++ b/tensorflow/compiler/mlir/quantization/common/python/testing.py @@ -1,5 +1,3 @@ -"""Common testing utilities for quantization libraries.""" - # Copyright 2024 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Common testing utilities for quantization libraries.""" import itertools +import os from typing import Any, Mapping, Sequence @@ -36,3 +36,34 @@ def parameter_combinations( for curr in itertools.product(*parameters.values()): real_parameters.append(dict(zip(keys, curr))) return real_parameters + + +def get_dir_size(path: str = '.') -> int: + """Get the total size of files and sub-directories under the path. + + Args: + path: Path of a directory or a file to calculate the total size. + + Returns: + Total size of the directory or a file. + """ + total = 0 + for root, _, files in os.walk(path): + for filename in files: + total += os.path.getsize(os.path.join(root, filename)) + return total + + +def get_size_ratio(path_a: str, path_b: str) -> float: + """Return the size ratio of the given paths. + + Args: + path_a: Path of a directory or a file to be the nominator of the ratio. + path_b: Path of a directory or a file to be the denominator of the ratio. + + Returns: + Ratio of size of path_a / size of path_b. + """ + size_a = get_dir_size(path_a) + size_b = get_dir_size(path_b) + return size_a / size_b diff --git a/tensorflow/compiler/mlir/quantization/common/python/testing_test.py b/tensorflow/compiler/mlir/quantization/common/python/testing_test.py index 3366959456d5fe..9549e10898cedb 100644 --- a/tensorflow/compiler/mlir/quantization/common/python/testing_test.py +++ b/tensorflow/compiler/mlir/quantization/common/python/testing_test.py @@ -37,5 +37,27 @@ def test_parameter_combinations(self): self.assertIn({'shapes': [3, None], 'has_bias': False}, combinations) +class FileSizeTestCase(test.TestCase): + + def setUp(self): + super().setUp() + + self.path_a = self.create_tempdir('dir_a').full_path + self.create_tempfile(file_path='dir_a/w.txt', content='abcd') + + self.path_b = self.create_tempdir('dir_b').full_path + self.create_tempfile(file_path='dir_b/x.txt', content='1234') + self.create_tempfile(file_path='dir_b/y.txt', content='56') + self.create_tempfile(file_path='dir_b/z.txt', content='78') + + def test_get_dir_size(self): + self.assertEqual(testing.get_dir_size(self.path_a), 4) + self.assertEqual(testing.get_dir_size(self.path_b), 8) + + def test_get_size_ratio(self): + self.assertEqual(testing.get_size_ratio(self.path_a, self.path_b), 0.5) + self.assertEqual(testing.get_size_ratio(self.path_b, self.path_a), 2.0) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD index d41a189519fd6d..7c68bb0f0c4b04 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD @@ -31,10 +31,12 @@ cc_library( ":quantization_config", ":quantization_interfaces_inc_gen", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/tools/optimize:quantization_utils", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -52,6 +54,7 @@ tf_cc_test( srcs = ["quantization_driver_test.cc"], deps = [ ":quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/common:func", "//tensorflow/compiler/mlir/quantization/common:test_base", @@ -62,6 +65,8 @@ tf_cc_test( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc index 962c6656f55b65..327d109946e031 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc @@ -26,7 +26,6 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project @@ -47,39 +46,44 @@ limitations under the License. namespace mlir { namespace quant { - namespace { -// This is used to identify an operand or result of an op. The second element -// of this pair is the index of the operand or result. -using OpValue = std::pair; + +constexpr int32_t kBiasMax = std::numeric_limits::max() / 2; // Uses the type of `value` to set the initial state of the index-th result if // `as_result` is true or index-th operand if `as_result` is false. The state // is immutable if the type is a quantized type. Returns the index of this // new state in the state vector. -void InitializeStateForValue(Operation* op, const int index, const Value value, - const bool as_result, - std::vector* states, - llvm::DenseMap* value_to_state, - llvm::DenseMap* operand_states, - llvm::DenseMap* result_states) { - const auto [cached, inserted] = value_to_state->insert({value, 0}); +void InitializeStateForValue( + Operation* op, const int index, const Value value, const bool as_result, + std::vector& states, + DenseMap& value_to_state, + DenseMap& operand_states, + DenseMap& result_states) { + const auto [cached, inserted] = value_to_state.try_emplace(value, 0); if (!inserted) { - if (as_result) - (*result_states)[{op, index}] = cached->second; - else - (*operand_states)[{op, index}] = cached->second; + if (as_result) { + result_states[{op, index}] = cached->second; + } else { + operand_states[{op, index}] = cached->second; + } return; } - const QuantParams params = - quant::QuantizedType::getQuantizedElementType(value.getType()); - const bool immutable = !HasQuantParams(params); - const int next_state_index = states->size(); - states->push_back({params, immutable}); - if (as_result) - (*result_states)[{op, index}] = next_state_index; - else - (*operand_states)[{op, index}] = next_state_index; + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(value.getType()); + + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states.size(); + states.push_back({quantized_type, immutable}); + if (as_result) { + result_states[{op, index}] = next_state_index; + } else { + operand_states[{op, index}] = next_state_index; + } + cached->second = next_state_index; } @@ -87,32 +91,31 @@ void InitializeStateForValue(Operation* op, const int index, const Value value, void QuantizationDriver::InitializeArgState(const BlockArgument arg, const Value arg_value) { - const auto [cached, inserted] = value_to_state_.insert({arg_value, 0}); + const auto [cached, inserted] = value_to_state_.try_emplace(arg_value, 0); if (!inserted) { arg_states_[arg] = cached->second; return; } - const QuantParams params = - quant::QuantizedType::getQuantizedElementType(arg_value.getType()); - const bool immutable = !HasQuantParams(params); - const int next_state_index = states_.size(); - states_.push_back({params, immutable}); + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(arg_value.getType()); + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states_.size(); + states_.push_back({quantized_type, immutable}); arg_states_[arg] = next_state_index; cached->second = next_state_index; } void QuantizationDriver::InitializeOperandState(Operation* op, const int index, const Value value) { - ::mlir::quant::InitializeStateForValue(op, index, value, /*as_result=*/false, - &states_, &value_to_state_, - &operand_states_, &result_states_); + InitializeStateForValue(op, index, value, /*as_result=*/false, states_, + value_to_state_, operand_states_, result_states_); } void QuantizationDriver::InitializeResultState(Operation* op, const int index, const Value value) { - ::mlir::quant::InitializeStateForValue(op, index, value, /*as_result=*/true, - &states_, &value_to_state_, - &operand_states_, &result_states_); + InitializeStateForValue(op, index, value, /*as_result=*/true, states_, + value_to_state_, operand_states_, result_states_); } std::unique_ptr QuantizationDriver::GetQuantSpec(Operation* op) { @@ -133,11 +136,11 @@ bool QuantizationDriver::IsQuantized(Operation* op) { bool QuantizationDriver::SetConstantResultParams(Operation* op) { DenseFPElementsAttr attr; - const Value res = op->getResult(0); - if (!matchPattern(res, m_Constant(&attr))) { + const Value result = op->getResult(0); + if (!matchPattern(result, m_Constant(&attr))) { return false; } - // TODO(fengliuai): make storage_type_width and narrow_range configurable. + // TODO: b/323478683 - Make storage_type_width and narrow_range configurable. Type final_type; const auto it = optimized_weights_.find(op); const bool is_weight = it != optimized_weights_.end(); @@ -159,42 +162,44 @@ bool QuantizationDriver::SetConstantResultParams(Operation* op) { final_type = GetUniformQuantizedTypeForWeight( attr, /*symmetric=*/is_weight && is_signed_, /*num_bits=*/8, is_signed_, - /*narrow_range_=*/is_weight, legacy_float_scale_); + /*narrow_range=*/is_weight, legacy_float_scale_); } - if (const auto quant_type = - final_type.dyn_cast_or_null()) { - return SetResultParams(op, 0, quant_type); + if (const auto quant_type = final_type.dyn_cast_or_null(); + quant_type != nullptr) { + return SetResultParams(op, /*result_index=*/0, quant_type); } return false; } -bool QuantizationDriver::SetResultParams(Operation* op, const int res_index, - const QuantParams params) { - auto& state = GetResultQuantState(op, res_index); - if (state.params == params) { +bool QuantizationDriver::SetResultParams(Operation* op, const int result_index, + const QuantizedType quantized_type) { + QuantState& state = GetResultQuantState(op, result_index); + if (state.params == quantized_type) { return false; } if (!state.IsEmpty()) { - auto& rescales = GetResultRequantizeStates(op, res_index); + RequantizeStates& rescales = GetResultRequantizeStates(op, result_index); RequantizeState& rescale = rescales.emplace_back(); rescale.pos = RequantizeState::ON_INPUT; - rescale.params = params; + rescale.params = quantized_type; return true; } - state.params = params; - AddUserToList(op, res_index); + state.params = quantized_type; + AddUserToList(op, result_index); return true; } -QuantParams QuantizationDriver::GetBiasParams( - Operation* op, const int bias_index, const std::vector& non_biases, +QuantizedType QuantizationDriver::GetBiasParams( + Operation* op, const int bias_index, + const ArrayRef non_bias_operand_indices, const AccumulatorScaleFunc func) { QuantState& bias_state = GetOperandQuantState(op, bias_index); if (!bias_state.IsEmpty()) { return bias_state.params; } - std::vector op_types; - op_types.reserve(non_biases.size()); + std::vector op_types{}; + op_types.reserve(non_bias_operand_indices.size()); + int adjusted_quant_dim = -1; if (op->getNumOperands() > bias_index) { // Some kernels allow 1D bias, broadcasting it inside the kernel. In this @@ -211,68 +216,75 @@ QuantParams QuantizationDriver::GetBiasParams( } } - for (int non_bias : non_biases) { - const QuantState& non_bias_type = GetOperandQuantState(op, non_bias); - op_types.push_back(non_bias_type.params); + for (const int non_bias_operand_index : non_bias_operand_indices) { + const QuantState& non_bias_state = + GetOperandQuantState(op, non_bias_operand_index); + op_types.push_back(non_bias_state.params); } return func(op_types, adjusted_quant_dim, legacy_float_scale_); } -bool QuantizationDriver::SetOperandParams(Operation* op, const int index, - const QuantParams params, +bool QuantizationDriver::SetOperandParams(Operation* op, + const int operand_index, + const QuantizedType quantized_type, const bool override) { - auto& state = GetOperandQuantState(op, index); - if (state.params == params) { + QuantState& state = GetOperandQuantState(op, operand_index); + if (state.params == quantized_type) { return false; } if (!state.IsEmpty() && !override) { - auto& rescales = GetOperandRequantizeStates(op, index); + RequantizeStates& rescales = GetOperandRequantizeStates(op, operand_index); for (RequantizeState& rescale : rescales) { - if (rescale.params == params) { - rescale.users.emplace_back(op, index); + if (rescale.params == quantized_type) { + rescale.users.emplace_back(op, operand_index); return true; } } RequantizeState& rescale = rescales.emplace_back(); rescale.pos = RequantizeState::ON_OUTPUT; - rescale.params = params; - rescale.users.emplace_back(op, index); + rescale.params = quantized_type; + rescale.users.emplace_back(op, operand_index); return true; } - state.params = params; - AddOperandToList(op, index); + state.params = quantized_type; + AddOperandToList(op, operand_index); return true; } -void QuantizationDriver::QuantizeOpResult(Operation* op, const int index, - const QuantParams params) { +void QuantizationDriver::QuantizeOpResult(Operation* op, const int result_index, + const QuantizedType quantized_type) { builder_.setInsertionPointAfter(op); - const Value original_result = op->getResult(index); - QuantizeValue(original_result, params, op->getLoc()); + const Value original_result = op->getResult(result_index); + QuantizeValue(original_result, quantized_type, op->getLoc()); } -void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) { +void QuantizationDriver::QuantizeArg(BlockArgument arg, + const QuantizedType quantized_type) { builder_.setInsertionPointToStart(arg.getOwner()); - QuantizeValue(arg, params, builder_.getUnknownLoc()); + QuantizeValue(arg, quantized_type, builder_.getUnknownLoc()); } -void QuantizationDriver::QuantizeValue(Value value, QuantParams params, - Location loc) { +void QuantizationDriver::QuantizeValue(Value value, + QuantizedType quantized_type, + const Location loc) { const Type expressed_type = value.getType(); - const Type new_type = params.castFromExpressedType(expressed_type); - // This value isn't an expressed type (float), skip. - if (!new_type) return; + const Type new_value_type = + quantized_type.castFromExpressedType(expressed_type); + // Skip if `value` or `value`'s element type doesn't match the expressed type + // of `quantized_type`. + if (new_value_type == nullptr) return; + auto quantize = - builder_.create(loc, new_type, value); + builder_.create(loc, new_value_type, value); auto dequantize = builder_.create( loc, expressed_type, quantize.getResult()); // This attribute is set to distinguish the quantize ops being added by the // quantization pass. These ops can be removed without losing original // program accuracy. - // TODO(fengliuai): make the attribute being part of op definition. + // TODO: b/323478683 - Make the attribute being part of op definition. quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr()); // `original_result` has a use to `quantize`, so this will replace that use @@ -281,17 +293,18 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params, quantize.getOperation()->replaceUsesOfWith(dequantize, value); } -void QuantizationDriver::RequantizeOpResult(Operation* op, const int index, - RequantizeStates* states) { - if (states->empty()) return; +void QuantizationDriver::RequantizeOpResult(Operation* op, + const int result_index, + RequantizeStates& states) { + if (states.empty()) return; builder_.setInsertionPointAfter(op); - Value value = op->getResult(index); - RequantizeState::RequantizePosition pos = states->front().pos; + Value value = op->getResult(result_index); + RequantizeState::RequantizePosition pos = states.front().pos; if (pos == RequantizeState::NO_REQUANTIZE) { return; } - for (auto& state : *states) { + for (const RequantizeState& state : states) { // Check that all requantization positions are the same for each state. // Unsure if this check is required. if (state.pos != pos) { @@ -300,7 +313,7 @@ void QuantizationDriver::RequantizeOpResult(Operation* op, const int index, } if (pos == RequantizeState::ON_OUTPUT) { Operation* user = value.getUses().begin().getUser(); - if (llvm::isa(user)) { + if (isa(user)) { // The requantize op is inserted between `quantize` and `dequantize` ops. value = user->getResult(0); builder_.setInsertionPointAfter(user); @@ -310,12 +323,12 @@ void QuantizationDriver::RequantizeOpResult(Operation* op, const int index, } void QuantizationDriver::RequantizeArg(const BlockArgument arg, - RequantizeStates* states) { + RequantizeStates& states) { Value value = arg; builder_.setInsertionPointToStart(arg.getOwner()); if (value.hasOneUse()) { Operation* user = value.use_begin().getUser(); - if (auto q = llvm::dyn_cast(user)) { + if (auto q = dyn_cast(user)) { value = q.getResult(); builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user)); } @@ -323,14 +336,13 @@ void QuantizationDriver::RequantizeArg(const BlockArgument arg, RequantizeValue(value, states, builder_.getUnknownLoc()); } -void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, +void QuantizationDriver::RequantizeValue(Value value, RequantizeStates& states, const Location loc) { - if (states->empty() || - states->front().pos == RequantizeState::NO_REQUANTIZE) { + if (states.empty() || states.front().pos == RequantizeState::NO_REQUANTIZE) { return; } - if (states->front().pos == RequantizeState::ON_INPUT) { - auto& state = states->front(); + if (states.front().pos == RequantizeState::ON_INPUT) { + RequantizeState& state = states.front(); const Type expressed_type = value.getType(); // The value needs to be requantized. A Quantize op will be created to use // it as the operand and replace its uses. @@ -350,7 +362,7 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, if (!value.hasOneUse()) { return; } - auto dequant_op = llvm::dyn_cast_or_null( + auto dequant_op = dyn_cast_or_null( value.use_begin().getUser()); if (!dequant_op) { return; @@ -363,10 +375,9 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, // Whether to replace quantization params of the first dequantize op // after the quantized value is produced. // If there is a use other than the requantize states, then we can't clobber. - bool clobber_first = num_uses <= states->size(); - for (auto& state : *states) { - Type expressed_type = - quant::QuantizedType::castToExpressedType(value.getType()); + bool clobber_first = num_uses <= states.size(); + for (RequantizeState& state : states) { + Type expressed_type = QuantizedType::castToExpressedType(value.getType()); if (!expressed_type) continue; // The value needs to be requantized. A Quantize op will be created to use // it as the operand and replace its uses. @@ -384,8 +395,8 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, } else { auto new_dequant_op = builder_.create( loc, dequant_op.getResult().getType(), requantize_op.getResult()); - for (auto& op_index : state.users) { - op_index.first->setOperand(op_index.second, new_dequant_op.getResult()); + for (auto [op, operand_idx] : state.users) { + op->setOperand(operand_idx, new_dequant_op.getResult()); } } } @@ -400,12 +411,12 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, // - use the single input if it is ready, or, // - use the single output if it is ready, or, // - use the first ready one in the collection. -QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint( +QuantizedType QuantizationDriver::GetQuantParamsForSameScaleConstraint( Operation* op) { // Two vector to collect Non-empty operands and results states. std::vector mutable_states, immutable_states; for (int i = 0; i < op->getNumOperands(); ++i) { - auto& state = GetOperandQuantState(op, i); + QuantState& state = GetOperandQuantState(op, i); if (state.immutable) { immutable_states.push_back(&state); } else if (!state.IsEmpty()) { @@ -422,7 +433,7 @@ QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint( } for (int i = 0; i < op->getNumResults(); ++i) { - auto& state = GetResultQuantState(op, i); + QuantState& state = GetResultQuantState(op, i); if (state.immutable) { immutable_states.push_back(&state); } else if (!state.IsEmpty()) { @@ -476,14 +487,11 @@ void QuantizationDriver::PreprocessConstantOps() { // The following loop will change the value uses, thus we cache all the uses // needs to be changed. - llvm::SmallVector> uses; - for (auto& use : value.getUses()) { + SmallVector> uses; + for (OpOperand& use : value.getUses()) { uses.push_back({use.getOwner(), use.getOperandNumber()}); } - for (const auto& indexed_use : llvm::enumerate(uses)) { - Operation* user = indexed_use.value().first; - const int operand_num = indexed_use.value().second; - + for (const auto [user, operand_num] : uses) { const std::unique_ptr spec = GetQuantSpec(user); const std::unique_ptr scale_spec = GetQuantScaleSpec(user); @@ -493,9 +501,9 @@ void QuantizationDriver::PreprocessConstantOps() { // other values. So any constants which are not bias, an operand of an // op with same scale requirements, and haven't been quantized are // weights. - if (biases.find(operand_num) == biases.end() && + if (!biases.contains(operand_num) && !scale_spec->has_same_scale_requirement && - !llvm::dyn_cast(user)) { + !dyn_cast(user)) { // Needs to scan the content of weights to get the quantization // parameters if there are no quantization parameters (FakeQuant ops). // For this case, the weight will not be duplicated. @@ -511,9 +519,9 @@ void QuantizationDriver::PreprocessConstantOps() { // other values. Duplicate this constant in case it is shared by // different users. if (uses.size() > 1) { - auto new_cst = + auto new_constant_op = builder_.create(cst.getLoc(), cst.getValue()); - user->setOperand(operand_num, new_cst); + user->setOperand(operand_num, new_constant_op); } } } @@ -521,13 +529,13 @@ void QuantizationDriver::PreprocessConstantOps() { } void QuantizationDriver::SetupAllStates() { - for (auto arg : fn_.getArguments()) { + for (BlockArgument arg : fn_.getArguments()) { args_.push_back(arg); Value value = arg; // If the argument is quantized, it should only has one user. if (arg.hasOneUse()) { Operation* user = value.use_begin().getUser(); - if (auto q = llvm::dyn_cast(user)) { + if (auto q = dyn_cast(user)) { value = q.getResult(); } } @@ -543,29 +551,29 @@ void QuantizationDriver::SetupAllStates() { for (int i = 0; i < op->getNumOperands(); ++i) { Value operand = op->getOperand(i); - if (auto* inst = operand.getDefiningOp()) { + if (Operation* inst = operand.getDefiningOp()) { // If the operand comes from a `quantfork::DequantizeCastOp`, we use // the quantized input of this `quantfork::DequantizeCastOp` to set the // state. - if (auto dq = llvm::dyn_cast(inst)) { + if (auto dq = dyn_cast(inst)) { operand = dq.getArg(); } } InitializeOperandState(op, i, operand); } - for (int res = 0; res < op->getNumResults(); ++res) { - Value result = op->getResult(res); + for (int i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); // If the result has been quantized, it should only be used by a // `quantfork::QuantizeCastOp`. For this case, we uses the quantized // result to create the state and mark it immutable. if (result.hasOneUse()) { Operation* user = result.use_begin().getUser(); - if (auto q = llvm::dyn_cast(user)) { + if (auto q = dyn_cast(user)) { result = q.getResult(); } } - InitializeResultState(op, res, result); + InitializeResultState(op, i, result); } }); } @@ -577,7 +585,7 @@ arith::ConstantOp QuantizationDriver::DuplicateConstantOpIfNeeded( } OpBuilder builder(op->getContext()); builder.setInsertionPointAfter(op); - arith::ConstantOp new_op = llvm::cast(builder.clone(*op)); + arith::ConstantOp new_op = cast(builder.clone(*op)); target_op->getOpOperand(operand_index).set(new_op.getResult()); InitializeOperandState(target_op, operand_index, new_op.getResult()); InitializeResultState(new_op, 0, new_op.getResult()); @@ -585,13 +593,13 @@ arith::ConstantOp QuantizationDriver::DuplicateConstantOpIfNeeded( } bool QuantizationDriver::ShouldCheckBiasScale( - Operation* op, const int bias_index, const std::vector& input_indices, - const QuantParams params, int& input_index, int& filter_index) { + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType quantized_type, int& input_index, int& filter_index) { // For now, restrict scale adjustment to ops with affine quantized weights, // and having weights and biases as constants. This currently only applies to // FC and Conv* ops. Restriction for the weight can be relaxed if there are // needs for adjusting scale of variable weights. - auto affine_op = llvm::dyn_cast(op); + auto affine_op = dyn_cast(op); auto bias_op = op->getOperand(bias_index).getDefiningOp(); if (!affine_op || !bias_op || input_indices.size() != 2) return false; if (!bias_op.getValue().isa()) return false; @@ -607,22 +615,20 @@ bool QuantizationDriver::ShouldCheckBiasScale( return false; } - const auto input_state = GetOperandQuantState(op, input_index); - const auto filter_state = GetOperandQuantState(op, filter_index); + const QuantState& input_state = GetOperandQuantState(op, input_index); + const QuantState& filter_state = GetOperandQuantState(op, filter_index); // If quantization parameter for the filter is fixed, should return it as-is. // Only checks ops with 8-bit input and weights, and 32-bit biases. - if (!(input_state.params.getStorageTypeIntegralWidth() == 8 && - filter_state.params.getStorageTypeIntegralWidth() == 8 && - params.getStorageTypeIntegralWidth() == 32)) { - return false; - } - return true; + return input_state.params.getStorageTypeIntegralWidth() == 8 && + filter_state.params.getStorageTypeIntegralWidth() == 8 && + quantized_type.getStorageTypeIntegralWidth() == 32; } bool QuantizationDriver::SetBiasParamsWithAdjustments( - Operation* op, const int bias_index, const std::vector& input_indices, - const QuantParams params) { + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType params) { bool changed = false; + int input_index; int filter_index; if (!ShouldCheckBiasScale(op, bias_index, input_indices, params, input_index, @@ -630,8 +636,8 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( return SetOperandParams(op, bias_index, params); } - quant::QuantState input_state = GetOperandQuantState(op, input_index); - quant::QuantState filter_state = GetOperandQuantState(op, filter_index); + QuantState input_state = GetOperandQuantState(op, input_index); + QuantState filter_state = GetOperandQuantState(op, filter_index); auto bias_op = op->getOperand(bias_index).getDefiningOp(); const double input_scale = input_state.params.cast().getScale(); @@ -639,15 +645,15 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( auto bias_values = bias_op.getValue().cast(); // Restrict maximum absolute value of bias within INT_MAX / 2, to make some // room for accumulator. - const int32_t kBiasMax = std::numeric_limits::max() / 2; - if (auto bias_params = params.dyn_cast()) { + if (auto bias_quantized_type = params.dyn_cast(); + bias_quantized_type != nullptr) { double bias_half_range = 0.0f; for (auto bias : bias_values.getValues()) { if (bias_half_range < std::abs(bias.convertToFloat())) { bias_half_range = std::abs(bias.convertToFloat()); } } - if (bias_half_range / bias_params.getScale() < kBiasMax) { + if (bias_half_range / bias_quantized_type.getScale() < kBiasMax) { return SetOperandParams(op, bias_index, params); } const double new_bias_scale = @@ -659,30 +665,36 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( bias_op->getLoc(), params.getFlags(), params.getStorageType(), params.getExpressedType(), new_bias_scale, 0, params.getStorageTypeMin(), params.getStorageTypeMax())); - auto filter_op = DuplicateConstantOpIfNeeded( + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( op->getOperand(filter_index).getDefiningOp(), op, filter_index); if (!filter_op) { return SetOperandParams(op, bias_index, params); } - const auto filter_param = filter_state.params.cast(); + const auto filter_quantized_type = + filter_state.params.cast(); changed |= SetOperandParams( op, filter_index, UniformQuantizedType::getChecked( - filter_op->getLoc(), filter_param.getFlags(), - filter_param.getStorageType(), filter_param.getExpressedType(), - new_bias_scale / input_scale, 0, filter_param.getStorageTypeMin(), - filter_param.getStorageTypeMax()), + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), + new_bias_scale / input_scale, 0, + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), /*override=*/true); - } else if (auto bias_params = - params.dyn_cast()) { - const auto filter_params = + } else if (auto bias_quantized_type = + params.dyn_cast(); + bias_quantized_type != nullptr) { + const auto filter_quantized_type = filter_state.params.cast(); - std::vector new_bias_scales = bias_params.getScales().vec(); - std::vector new_filter_scales = filter_params.getScales().vec(); + std::vector new_bias_scales = bias_quantized_type.getScales().vec(); + std::vector new_filter_scales = + filter_quantized_type.getScales().vec(); + bool needs_adjustment = false; - for (int i = 0; i < bias_params.getScales().size(); ++i) { + for (int i = 0; i < bias_quantized_type.getScales().size(); ++i) { const float abs_bias = std::abs(bias_values.getValues()[i]); if (abs_bias / new_bias_scales[i] > kBiasMax) { new_bias_scales[i] = static_cast(abs_bias) / kBiasMax; @@ -698,21 +710,23 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( quant::UniformQuantizedPerAxisType::getChecked( bias_op->getLoc(), params.getFlags(), params.getStorageType(), params.getExpressedType(), new_bias_scales, - bias_params.getZeroPoints(), bias_params.getQuantizedDimension(), + bias_quantized_type.getZeroPoints(), + bias_quantized_type.getQuantizedDimension(), params.getStorageTypeMin(), params.getStorageTypeMax())); - auto filter_op = DuplicateConstantOpIfNeeded( + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( op->getOperand(filter_index).getDefiningOp(), op, filter_index); changed |= SetOperandParams( op, filter_index, quant::UniformQuantizedPerAxisType::getChecked( - filter_op->getLoc(), filter_params.getFlags(), - filter_params.getStorageType(), filter_params.getExpressedType(), - new_filter_scales, filter_params.getZeroPoints(), - filter_params.getQuantizedDimension(), - filter_params.getStorageTypeMin(), - filter_params.getStorageTypeMax()), + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), new_filter_scales, + filter_quantized_type.getZeroPoints(), + filter_quantized_type.getQuantizedDimension(), + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), /*override=*/true); } return changed; @@ -720,12 +734,12 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( // This method scans the operations in the function to setup the initial // states for quantization parameter propagation. -// TODO(fengliuai): This algorithm assumes there are only one pair of +// TODO: b/323478683 - This algorithm assumes there are only one pair of // `quantfork::QuantizeCastOp` and `quantfork::DequantizeCastOp` ops between two // quantizable ops. A sanity check should be applied. void QuantizationDriver::Initialize() { // Duplicate the bias constant, so the states can be setup correctly. - // TODO(fengliuai): Function definition should also be duplicated if there + // TODO: b/323478683 - Function definition should also be duplicated if there // are multiple call sites. PreprocessConstantOps(); @@ -736,21 +750,21 @@ void QuantizationDriver::Initialize() { // Propagates the quantization parameters to the operands, results, and biases. // TODO: b/323478683 - Do not use while loop to handle this logic. bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { - // TODO(fengliuai): uses a typed indicator instead of a bool value. + // TODO: b/323478683 - Use a typed indicator instead of a bool value. bool changed = false; while (!work_list_.empty()) { Operation* op = work_list_.back(); work_list_.pop_back(); // This op has been quantized, so we should not consider it again. - if (llvm::is_contained(quantized_, op)) continue; + if (quantized_.contains(op)) continue; quantized_.insert(op); - if (auto cst = llvm::dyn_cast(op)) { + if (auto constant_op = dyn_cast(op); constant_op) { // If the workflow requires inferring ranges from the content // (post-training quantization) and it is weight (filter) and hasn't // been quantized, we infer the quantization parameters from the content. - if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) { + if (infer_tensor_range_ && IsWeight(constant_op) && !IsQuantized(op)) { // The quantization parameters are determined by the content of the // constant. changed |= SetConstantResultParams(op); @@ -761,7 +775,7 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { std::unique_ptr scale_spec = GetQuantScaleSpec(op); if (scale_spec->has_same_scale_requirement) { - const auto params = GetQuantParamsForSameScaleConstraint(op); + const QuantizedType params = GetQuantParamsForSameScaleConstraint(op); // The quantization parameters haven't been propagated to any operands // or results. Skip this node for now. if (!params) { @@ -792,12 +806,13 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { } // Use the final state to set all the results' parameters. - for (int res = 0; res < op->getNumResults(); ++res) - if (auto type = op->getResult(res).getType().dyn_cast()) { + for (int i = 0; i < op->getNumResults(); ++i) + if (auto type = op->getResult(i).getType().dyn_cast(); + type != nullptr) { // Without this check, it will accidentally propagate the quantization // information by the shared non-float-tensors. if (type.getElementType().isa()) - changed |= SetResultParams(op, res, params); + changed |= SetResultParams(op, i, params); } } @@ -807,8 +822,8 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { !is_qdq_conversion_) { // Infer ranges from the activation ops. This is usually required for // the post-training quantization workflow. - // TODO(fengliuai): different result can have different fixed range. - const auto params = + // TODO: b/323478683 - Different result can have different fixed range. + const QuantizedType params = scale_spec->fixed_output_range_func(is_signed_, bit_width_); for (auto i = 0; i < op->getNumResults(); ++i) { // The range is null if the result has been quantized. @@ -818,16 +833,20 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { } } - const auto spec = GetQuantSpec(op); - for (auto& it : spec->biases_params) { - const auto params = - GetBiasParams(op, it.first, it.second.first, it.second.second); + const std::unique_ptr spec = GetQuantSpec(op); + for (const auto& [bias_operand_idx, non_bias_params] : + spec->biases_params) { + const auto& [non_bias_operand_indices, accumulator_scale_func] = + non_bias_params; + const QuantizedType params = + GetBiasParams(op, bias_operand_idx, non_bias_operand_indices, + accumulator_scale_func); if (!params) { quantized_.erase(op); continue; } - changed |= - SetBiasParamsWithAdjustments(op, it.first, it.second.first, params); + changed |= SetBiasParamsWithAdjustments(op, bias_operand_idx, + non_bias_operand_indices, params); } } @@ -836,9 +855,9 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { // Finalizes the arguments and result states in the function. void QuantizationDriver::Finalize() { - for (auto arg : args_) { - auto& state = GetArgQuantState(arg); - auto& requantizes = GetArgRequantizeStates(arg); + for (BlockArgument arg : args_) { + const QuantState& state = GetArgQuantState(arg); + RequantizeStates& requantizes = GetArgRequantizeStates(arg); if (state.IsEmpty() || (state.immutable && requantizes.empty())) { continue; } @@ -848,25 +867,24 @@ void QuantizationDriver::Finalize() { } if (!requantizes.empty()) { - RequantizeArg(arg, &requantizes); + RequantizeArg(arg, requantizes); } } - for (auto it : result_states_) { - Operation* op = it.first.first; - const int res_index = it.first.second; - auto& state = GetResultQuantState(op, res_index); - auto& requantizes = GetResultRequantizeStates(op, res_index); + for (const auto& [op_with_result_idx, quant_state_idx] : result_states_) { + const auto [op, result_idx] = op_with_result_idx; + const QuantState& state = GetResultQuantState(op, result_idx); + RequantizeStates& requantizes = GetResultRequantizeStates(op, result_idx); if (state.IsEmpty() || (state.immutable && requantizes.empty())) { continue; } if (!state.immutable) { - QuantizeOpResult(op, res_index, state.params); + QuantizeOpResult(op, result_idx, state.params); } if (!requantizes.empty()) { - RequantizeOpResult(op, res_index, &requantizes); + RequantizeOpResult(op, result_idx, requantizes); } } } @@ -885,7 +903,7 @@ void QuantizationDriver::Run() { } void ApplyQuantizationParamsPropagation( - const mlir::func::FuncOp func, const bool is_signed, const int bit_width, + const func::FuncOp func, const bool is_signed, const int bit_width, const bool disable_per_channel, const OpQuantSpecGetter op_quant_spec_getter, const bool infer_tensor_ranges, const bool legacy_float_scale, @@ -897,7 +915,7 @@ void ApplyQuantizationParamsPropagation( } void ApplyQuantizationParamsPropagation( - const mlir::func::FuncOp func, const bool is_signed, const int bit_width, + const func::FuncOp func, const bool is_signed, const int bit_width, const bool disable_per_channel, const OpQuantSpecGetter op_quant_spec_getter, const OpQuantScaleSpecGetter op_quant_scale_spec_getter, diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h index 59741f48307a16..070ecb75f5db5b 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h @@ -17,14 +17,13 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ #include -#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project @@ -40,20 +39,16 @@ limitations under the License. namespace mlir { namespace quant { -static bool HasQuantParams(QuantParams p) { - return p == quant::QuantizedType(); -} - // The state for each op result during the quantization parameters propagation. struct QuantState { // Quantization parameters propagated to an op result. - QuantParams params; + QuantizedType params; // A flag indicates this state (the params) shouldn't be changed after it is // initialized. This flag will be set to true if the quantization parameters // are from the quantization-aware training. const bool immutable; - bool IsEmpty() { return HasQuantParams(params); } + bool IsEmpty() const { return params == nullptr; } }; // The state for rescaling the propagated quantization parameters. This can be @@ -70,7 +65,7 @@ struct RequantizeState { } pos = NO_REQUANTIZE; // Quantization parameters will be used to add the requantize ops. - QuantParams params; + QuantizedType params; // Avoid clobbering all uses of the value, limit to just these ops. SmallVector> users; @@ -99,15 +94,25 @@ using RequantizeStates = SmallVector; // class QuantizationDriver { public: - explicit QuantizationDriver(func::FuncOp fn, bool is_signed, int bit_width, - bool disable_per_channel, + // Type alias of int used to access `states_`. + using QuantStateIndex = int; + + // (op, operand index) pair. + using OpWithOperandIndex = std::pair; + + // (op, result index) pair. + using OpWithResultIndex = std::pair; + + explicit QuantizationDriver(func::FuncOp func_op, const bool is_signed, + const int bit_width, + const bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, - bool infer_tensor_range, - bool legacy_float_scale = false, - bool is_qdq_conversion = false) - : fn_(fn), - builder_(fn.getBody()), + const bool infer_tensor_range, + const bool legacy_float_scale = false, + const bool is_qdq_conversion = false) + : fn_(func_op), + builder_(func_op.getBody()), is_signed_(is_signed), bit_width_(bit_width), disable_per_channel_(disable_per_channel), @@ -130,18 +135,25 @@ class QuantizationDriver { // result. void Finalize(); - llvm::SmallVector GetArgs() { return args_; } + SmallVector GetArgs() { return args_; } + + llvm::DenseMap, int> GetResultStates() { + return result_states_; + } + + DenseMap result_states_; // Returns the state of the block argument. QuantState& GetArgQuantState(BlockArgument arg) { return states_[arg_states_[arg]]; } - private: - // This is used to identify an operand or result of an op. The second element - // of this pair is the index of the operand or result. - using OpValue = std::pair; + // Returns the state of the index-th result of the op. + QuantState& GetResultQuantState(Operation* op, const int index) { + return states_[result_states_[{op, index}]]; + } + private: // Duplicates the constant op if it has multiple uses, and replaces // target_op->operand[operand_index] with the newly created op. This also // replaces corresponsing quantization states. @@ -153,13 +165,13 @@ class QuantizationDriver { // prevent overflow of quantized bias values. This also changes quantization // state of other inputs when needed. bool SetBiasParamsWithAdjustments(Operation* op, int bias_index, - const std::vector& input_indices, - QuantParams params); + ArrayRef input_indices, + QuantizedType params); // Checks preconditions to adjust bias scale. bool ShouldCheckBiasScale(Operation* op, int bias_index, - const std::vector& input_indices, - QuantParams params, int& input_index, + ArrayRef input_indices, + QuantizedType quantized_type, int& input_index, int& filter_index); // Preprocesses the constants by doing the following: @@ -187,84 +199,87 @@ class QuantizationDriver { bool IsQuantized(Operation* op); // Adds all the users of index-th result of op to the work list. - void AddUserToList(Operation* op, int index) { + void AddUserToList(Operation* op, const int index) { for (Operation* user : op->getResult(index).getUsers()) { work_list_.push_back(user); } } // Adds the defining op of index-th operand of op to the work list. - void AddOperandToList(Operation* op, int index) { - if (Operation* inst = op->getOperand(index).getDefiningOp()) { - work_list_.push_back(inst); + void AddOperandToList(Operation* op, const int index) { + if (Operation* operand_op = op->getOperand(index).getDefiningOp(); + operand_op != nullptr) { + work_list_.push_back(operand_op); } } // Returns the quantization params for the bias input from the non-bias // operands which have their indexes in the `non_biases` vector. The returned // parameters are calculated by `func`. - QuantParams GetBiasParams(Operation* op, int bias_index, - const std::vector& non_biases, - AccumulatorScaleFunc func); - - // Sets the quantization parameters of the result to a fixed value. If any - // quantization parameters have been propagated, a `requantize` will happen on - // the input of propagated quantization. - bool SetResultParams(Operation* op, int index, QuantParams params); - - // Sets the quantization parameters of the operand to a fixed value. If any + QuantizedType GetBiasParams(Operation* op, int bias_index, + ArrayRef non_bias_operand_indices, + AccumulatorScaleFunc func); + + // Sets the quantization parameters of the result to `quantized_type`. If + // any quantization parameters have been propagated, a requantize will + // happen on the input of propagated quantization. Returns `true` if internal + // state has been modified. + bool SetResultParams(Operation* op, int result_index, + QuantizedType quantized_type); + + // Sets the quantization parameters of the operand to `quantized_type`. If any // quantization parameters have been propagated, a `requantize` will happen on // the output of propagated quantization. When `override` is set, quantization - // state of the value is replaced instead of adding requantization. - bool SetOperandParams(Operation* op, int index, QuantParams params, - bool override = false); + // state of the value is replaced instead of adding requantization. Returns + // `true` if internal state has been modified. + bool SetOperandParams(Operation* op, int operand_index, + QuantizedType quantized_type, bool override = false); // Sets the quantization parameters of the constant result according to its // content. bool SetConstantResultParams(Operation* op); - // Inserts the Quantize and Dequantize ops for quantizing the index-th result - // of the op. - void QuantizeOpResult(Operation* op, int index, QuantParams params); + // Inserts the Quantize and Dequantize ops after `op`'s `index`-th result. The + // quantized element type for the result is `quantized_type`. + void QuantizeOpResult(Operation* op, int result_index, + QuantizedType quantized_type); - void QuantizeArg(BlockArgument arg, QuantParams params); + // Inserts the Quantize and Dequantize ops after `arg`. The quantized element + // type for `arg` is `quantized_type`. + void QuantizeArg(BlockArgument arg, QuantizedType quantized_type); - // Inserts the Quantize and Dequantize ops to quantize the value and returns - // the Quantize op. - void QuantizeValue(Value value, QuantParams params, Location loc); + // Inserts the Quantize and Dequantize ops (i.e. QDQ) after `value`. The + // quantized element type for `value` is `quantized_type`. + void QuantizeValue(Value value, QuantizedType quantized_type, Location loc); // Inserts the Quantize ops for requantizing the index-th result of the op. - void RequantizeOpResult(Operation* op, int index, RequantizeStates* states); + void RequantizeOpResult(Operation* op, int result_index, + RequantizeStates& states); // Inserts the Quantize ops for requantizing a block argument. - void RequantizeArg(BlockArgument arg, RequantizeStates* states); + void RequantizeArg(BlockArgument arg, RequantizeStates& states); // Inserts the Quantize and Dequantize ops to quantize the value and returns // the Quantize op. - void RequantizeValue(Value value, RequantizeStates* states, Location loc); + void RequantizeValue(Value value, RequantizeStates& states, Location loc); // Returns the quantization parameter satisfies the same scale // constraints for the op. Returns an empty option if this quantization // parameter doesn't exist. - QuantParams GetQuantParamsForSameScaleConstraint(Operation* op); + QuantizedType GetQuantParamsForSameScaleConstraint(Operation* op); // Returns the state of the index-th operand of the op. - QuantState& GetOperandQuantState(Operation* op, int index) { + QuantState& GetOperandQuantState(Operation* op, const int index) { return states_[operand_states_[{op, index}]]; } - // Returns the state of the index-th result of the op. - QuantState& GetResultQuantState(Operation* op, int index) { - return states_[result_states_[{op, index}]]; - } - // Returns the states of the index-th operand of the op. - RequantizeStates& GetOperandRequantizeStates(Operation* op, int index) { + RequantizeStates& GetOperandRequantizeStates(Operation* op, const int index) { return rescale_states_[operand_states_[{op, index}]]; } // Returns the states of the index-th result of the op. - RequantizeStates& GetResultRequantizeStates(Operation* op, int index) { + RequantizeStates& GetResultRequantizeStates(Operation* op, const int index) { return rescale_states_[result_states_[{op, index}]]; } @@ -278,10 +293,6 @@ class QuantizationDriver { // a new entry in the state vector. void InitializeArgState(BlockArgument arg, Value arg_value); - // Sets the state of index-th operand / result of op. - void InitializeStateForValue(Operation* op, int index, Value value, - bool as_result); - // Sets the state of the index-th operand of the op. If this operand is // cached, uses the cached result without creating new entry in the state // vector. Otherwise, allocate a new entry in the state vector. @@ -301,12 +312,13 @@ class QuantizationDriver { // We should distinguish weights and bias constants. Biases are specified by // the quantization spec or are the operands of ops with same scale spec. The // rest are weights. - llvm::DenseSet weights_; + DenseSet weights_; // The weights require narrow_range quantization. This map collects all the - // weight operands defined by the op quant spec. If the value of the entry is - // positive, per-channel quantization is required. - llvm::DenseMap optimized_weights_; + // weight operands defined by the op quant spec. The value of each entry is + // the quantization dimension. If it is positive, per-channel quantization is + // required. + DenseMap optimized_weights_; // All the ops needs to propagate the quantization parameters to. std::vector work_list_; @@ -319,18 +331,17 @@ class QuantizationDriver { // The map contains all the quantization parameters which are required to // satisfy the same operands and results constraint. The keys of this map are // the values from `operand_states_` and `result_state_`. - std::unordered_map rescale_states_; + absl::flat_hash_map rescale_states_; // Maps of indexes to the propagation state vector from the ops operands, // results and arguments. - llvm::DenseMap operand_states_; - llvm::DenseMap result_states_; - llvm::DenseMap arg_states_; - llvm::DenseMap value_to_state_; + DenseMap operand_states_; + DenseMap arg_states_; + DenseMap value_to_state_; // This vector is to preserve the arguments order, so the newly inserted // quantized ops for the arguments are deterministically ordered. - llvm::SmallVector args_; + SmallVector args_; OpQuantSpecGetter op_quant_spec_getter_; OpQuantScaleSpecGetter op_quant_scale_spec_getter_; @@ -357,7 +368,7 @@ class QuantizationDriver { // Setting `infer_tensor_range` to true, to infer quantization parameters from // the activation ops and weight constants. This is only used for post-training // quantization. -void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed, +void ApplyQuantizationParamsPropagation(func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, bool infer_tensor_ranges, @@ -365,8 +376,8 @@ void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed, bool is_qdq_conversion); void ApplyQuantizationParamsPropagation( - mlir::func::FuncOp func, bool is_signed, int bit_width, - bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, + func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, bool legacy_float_scale, bool is_qdq_conversion); diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc index 1942ae56b0aba4..cc82c09894b46b 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc @@ -26,12 +26,16 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" @@ -80,7 +84,8 @@ std::unique_ptr GetOpQuantSpec( TEST_F(ApplyQuantizationParamsPropagationTest, ConstsUsedMultipleTimesAreDuplicated) { - OwningOpRef module_op_ref = ParseModuleOpString(kModuleTFLite); + const OwningOpRef module_op_ref = + ParseModuleOpString(kModuleTFLite); func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); auto op_quant_spec_getter = [&](Operation* op) { @@ -97,14 +102,13 @@ TEST_F(ApplyQuantizationParamsPropagationTest, int64_t num_constant_op = 0; main_fn.walk([&](arith::ConstantOp cst) { ++num_constant_op; }); - // TODO: b/323478683 - This should actually be 3. Bias parameter is - // duplicated one extra time. Tackle this in a follow-up cl. EXPECT_EQ(num_constant_op, 4); } TEST_F(ApplyQuantizationParamsPropagationTest, PropagateParamsCreatesQuantState) { - OwningOpRef module_op_ref = ParseModuleOpString(kModuleTFLite); + const OwningOpRef module_op_ref = + ParseModuleOpString(kModuleTFLite); func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); auto op_quant_spec_getter = [&](Operation* op) { @@ -120,16 +124,23 @@ TEST_F(ApplyQuantizationParamsPropagationTest, quantization_driver.Initialize(); ASSERT_TRUE(quantization_driver.PropagateParamsAndReturnIfChanged()); EXPECT_THAT(quantization_driver.GetArgs(), Not(IsEmpty())); + for (const auto& arg : quantization_driver.GetArgs()) { - QuantState& state = quantization_driver.GetArgQuantState(arg); - // TODO: b/323478683 - Below should not be empty. Inspect further to see - // if there is a bug. - EXPECT_TRUE(state.IsEmpty()); + const QuantState& state = quantization_driver.GetArgQuantState(arg); + EXPECT_TRUE(isa(state.params)); + } + for (const auto& result : quantization_driver.GetResultStates()) { + Operation* op = result.first.first; + const int res_index = result.first.second; + const QuantState state = + quantization_driver.GetResultQuantState(op, res_index); + EXPECT_TRUE(isa(state.params)); } } TEST_F(ApplyQuantizationParamsPropagationTest, FinalizeInsertsQDQOps) { - OwningOpRef module_op_ref = ParseModuleOpString(kModuleTFLite); + const OwningOpRef module_op_ref = + ParseModuleOpString(kModuleTFLite); func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); auto op_quant_spec_getter = [&](Operation* op) { @@ -146,8 +157,12 @@ TEST_F(ApplyQuantizationParamsPropagationTest, FinalizeInsertsQDQOps) { xla_call_module_op->getOperand(1).getDefiningOp(); Operation* filter_qcast_op = filter_dcast_op->getOperand(0).getDefiningOp(); ASSERT_NE(filter_qcast_op, nullptr); - // TODO: b/323478683 - Add check for `UniformQuantizedPerAxisType` below. - EXPECT_TRUE(filter_qcast_op->getResult(0).getType().isa()); + EXPECT_TRUE(isa(filter_qcast_op)); + EXPECT_TRUE(isa(filter_dcast_op)); + EXPECT_TRUE(isa(filter_qcast_op->getResult(0) + .getType() + .cast() + .getElementType())); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc index 5021805a879ef3..f6c561be98d49b 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc @@ -46,10 +46,10 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/tools/optimize/quantization_utils.h" diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h index d95ba49cf8e800..e1d36df58a3fd9 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h @@ -26,10 +26,10 @@ limitations under the License. #include #include #include -#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "llvm/ADT/DenseMap.h" @@ -54,8 +54,8 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/core/framework/types.pb.h" @@ -86,11 +86,11 @@ inline constexpr double kNearZeroTolerance = 1.0e-6; using QuantParams = QuantizedType; using QuantSpec = QuantizationSpecs; using SignedInteger = std::pair; // bitwidth and sign -using QuantParamsForResults = llvm::SmallVector; +using QuantParamsForResults = llvm::SmallVector; using AccumulatorScaleFunc = - std::function&, int, bool)>; + std::function&, int, bool)>; using BiasParamsMap = - std::unordered_map, AccumulatorScaleFunc>>; + absl::flat_hash_map, AccumulatorScaleFunc>>; // UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width) using GetFixedOutputRangeFunc = std::function; // bool RequiredSameOperandsAndResultsScale(bool sign, int $bit_width) diff --git a/tensorflow/compiler/mlir/quantization/common/test_base.h b/tensorflow/compiler/mlir/quantization/common/test_base.h index 46c069cc49011e..a1a770ff616dee 100644 --- a/tensorflow/compiler/mlir/quantization/common/test_base.h +++ b/tensorflow/compiler/mlir/quantization/common/test_base.h @@ -28,9 +28,11 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -69,6 +71,22 @@ class QuantizationTestBase : public Test { return nullptr; } + // Convenience function that returns the first operation of type `OpT` from + // the `@main` function in `module_op`. Useful when testing with a text + // representation of a `ModuleOp` containing a single function `@main`. + // Returns `failure` iff there is no `@main` or no such operation is found in + // `@main`. + template + FailureOr FindFirstOpFromMainFunc(ModuleOp module_op) { + func::FuncOp main_func_op = FindMainFuncOp(module_op); + if (main_func_op == nullptr) return failure(); + + auto ops = main_func_op.getOps(); + if (ops.empty()) return failure(); + + return *ops.begin(); + } + std::unique_ptr ctx_; OpBuilder builder_; }; diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h index 6c02f0d1dcbfd5..ab850c878ff0dd 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { @@ -78,6 +79,12 @@ bool IsStorageTypeI32(QuantizedType quantized_type); bool IsExpressedTypeF32(QuantizedType quantized_type); +// Given a value, extract the `ElementType`. +// `value` should be a non-null `TensorType`. +inline Type GetElementType(const Value value) { + return value.getType().cast().getElementType(); +} + // Returns true iff `type` is a uniform quantized type whose storage type is // 8-bit integer and expressed type is f32. bool IsI8F32UniformQuantizedType(Type type); diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc index 474c378acc1e0d..e9443a667fcef3 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/common/test_base.h" @@ -725,6 +726,28 @@ TEST_F(IsOpNotQuantizedTest, FalseIfOpPartiallyQuantized) { EXPECT_FALSE(IsOpNotQuantized(*uniform_quantize_op_itr)); } +using UniformQuantizedTypeTest = QuantizationTestBase; + +TEST_F(UniformQuantizedTypeTest, GetElementTypeSucceeds) { + constexpr absl::string_view kQuantizeOp = R"mlir( + func.func @quantize(%arg0: tensor<2xf32>) -> tensor<2x!quant.uniform> { + %0 = stablehlo.uniform_quantize %arg0 : (tensor<2xf32>) -> tensor<2x!quant.uniform> + return %0 : tensor<2x!quant.uniform> + } + )mlir"; + + OwningOpRef module_op = ParseModuleOpString(kQuantizeOp); + ASSERT_TRUE(module_op); + + auto func_op = module_op->lookupSymbol("quantize"); + ASSERT_THAT(func_op, NotNull()); + + auto uniform_quantize_op = + *func_op.getOps<::mlir::stablehlo::UniformQuantizeOp>().begin(); + Value result = uniform_quantize_op.getResult(); + EXPECT_THAT(GetElementType(result), NotNull()); +} + } // namespace } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index c6be5b32248221..f6b1d8ac9f3493 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -48,7 +48,9 @@ cc_library( srcs = [ "passes/convert_func_to_bfloat16.cc", "passes/convert_xla_call_module_op_to_bfloat16.cc", - "passes/fold_constant_transpose_pass.cc", + "passes/defer_activation_transpose.cc", + "passes/fold_constant_transpose.cc", + "passes/insert_weight_param.cc", "passes/lift_quantizable_spots_as_functions.cc", "passes/lift_quantizable_spots_as_functions_fusion.inc", "passes/lift_quantizable_spots_as_functions_simple.inc", @@ -56,7 +58,6 @@ cc_library( "passes/optimize_graph.cc", "passes/post_quantize.cc", "passes/prepare_quantize.cc", - "passes/prepare_quantize_hybrid.cc", "passes/quantize.cc", "passes/quantize_composite_functions.cc", "passes/quantize_weight.cc", @@ -93,6 +94,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:permutation", "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -112,6 +114,7 @@ cc_library( "//tensorflow/lite/kernels:padding", "//tensorflow/lite/kernels/internal:quantization_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", @@ -154,8 +157,10 @@ cc_library( ], compatible_with = get_compatible_with_portable(), deps = [ + ":quantization_config_proto_cc", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", @@ -166,6 +171,8 @@ cc_library( "//tensorflow/core/platform:path", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -327,10 +334,10 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:chlo_legalize_to_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", "@stablehlo//:chlo_ops", @@ -514,6 +521,7 @@ cc_library( ":quantization_config_proto_cc", ":stablehlo_test_passes_inc_gen", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:post_calibration", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pre_calibration", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -524,7 +532,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -595,6 +602,7 @@ tf_cc_test( deps = [ ":stablehlo_type_utils", "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@stablehlo//:stablehlo_ops", @@ -756,6 +764,7 @@ tf_cc_binary( ":test_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pass_pipeline", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index 7a36ad58dc34a4..77629c7719bf44 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -111,13 +111,32 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":graph_def", - "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "permutation", + hdrs = ["permutation.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "permutation_test", + srcs = ["permutation_test.cc"], + deps = [ + ":permutation", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Support", ], ) @@ -127,10 +146,16 @@ cc_library( hdrs = ["saved_model_export.h"], compatible_with = get_compatible_with_portable(), deps = [ + ":io", ":pass_pipeline", + ":saved_model_import", ":types", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:convert_asset_args", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:unfreeze_constants", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:export_graphdef", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", @@ -140,6 +165,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -150,6 +176,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -184,15 +211,26 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":types", + "//tensorflow/cc/saved_model:loader", "//tensorflow/cc/saved_model:reader", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -261,6 +299,7 @@ tf_cc_test( name = "pre_calibration_test", srcs = ["pre_calibration_test.cc"], deps = [ + ":config", ":pre_calibration", "//tensorflow/compiler/mlir/quantization/common:test_base", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", @@ -276,6 +315,26 @@ tf_cc_test( ], ) +cc_library( + name = "report", + srcs = ["report.cc"], + hdrs = ["report.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + ], +) + +tf_cc_test( + name = "report_test", + srcs = ["report_test.cc"], + deps = [ + ":report", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "context", srcs = [], @@ -357,3 +416,36 @@ cc_library( "@local_tsl//tsl/platform:statusor", ], ) + +cc_library( + name = "weight_only_ptq", + srcs = ["weight_only_ptq.cc"], + hdrs = ["weight_only_ptq.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":component", + ":context", + ":pass_pipeline", + ":saved_model_export", + ":saved_model_import", + ":types", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/mlir_hlo:mhlo_passes", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD index 90afbe53209347..5783ffddd4f050 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD @@ -25,14 +25,14 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", - "//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", - "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -44,28 +44,18 @@ cc_library( deps = [ ":representative_dataset", ":statistics", - "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:component", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:saved_model_export", - "//tensorflow/compiler/mlir/quantization/stablehlo/cc:saved_model_import", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:types", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", - "//tensorflow/compiler/mlir/quantization/tensorflow/cc:convert_asset_args", - "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", - "//tensorflow/compiler/mlir/quantization/tensorflow/python:unfreeze_constants", - "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core/protobuf:for_core_protos_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status", @@ -76,7 +66,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc index 494eadc8463143..ba1671ceb696ca 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc @@ -14,17 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h" -#include #include #include #include #include -#include "absl/algorithm/container.h" -#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/log/die_if_null.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -37,119 +33,32 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" -#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace mlir::quant::stablehlo { -namespace { using ::stablehlo::quantization::AddCalibrationStatistics; using ::stablehlo::quantization::CreateRepresentativeDatasetFileMap; +using ::stablehlo::quantization::DisableDebugging; using ::stablehlo::quantization::QuantizationConfig; using ::stablehlo::quantization::RepresentativeDatasetConfig; using ::stablehlo::quantization::io::CreateTmpDir; using ::stablehlo::quantization::io::GetLocalTmpFileName; using ::tensorflow::AssetFileDef; -using ::tensorflow::MLIRImportOptions; -using ::tensorflow::SavedModelBundle; -using ::tensorflow::SavedModelSignatureDefsToMlirImport; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::ExportedModel; -using ::tensorflow::quantization::PreprocessAndFreezeGraph; using ::tensorflow::quantization::PyFunctionLibrary; -using ::tensorflow::quantization::RunPasses; -using ::tensorflow::quantization::UnfreezeConstantsAndSaveVariables; - -using ImportedMlirModuleOp = - std::pair>; - -// Loads a SavedModel at `saved_model_path` and converts it to `mlir::ModuleOp`. -// -// `tags` identify the `tensorflow::MetaGraphDef` to load from the SavedModel. -// Similarly, `signature_keys` identify the functions (`SignatureDef`s) to load -// within the `MetaGraphDef`. `ctx` is the `MLIRContext`, which should outlive -// the returned `ModuleOp`, thus marked with the lifetime bound attribute. -absl::StatusOr SavedModelToMlirModuleOp( - const absl::string_view saved_model_path, - const std::unordered_set& tags, - const std::vector& signature_keys, - MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) { - MLIRImportOptions import_options; - import_options.upgrade_legacy = true; - import_options.lift_variables = false; - import_options.include_variables_in_initializers = true; - - auto bundle = std::make_unique(); - - // Copy to eliminate the `const` qualifier so that `absl::MakeSpan` can be - // called on it. - std::vector exported_names = signature_keys; - absl::StatusOr> module_op = - SavedModelSignatureDefsToMlirImport(saved_model_path, tags, - absl::MakeSpan(exported_names), &ctx, - import_options, &bundle); - if (!module_op.status().ok()) { - return absl::InternalError(absl::StrCat("Failed to import SavedModel: ", - module_op.status().ToString())); - } - - return std::make_pair(module_op->release(), std::move(bundle)); -} - -// Sets up and runs the passes for exporting `module_op`. The behavior of the -// exporting passes is controlled by `export_opts`. Returns `AssetFileDef`s that -// associate the input arguments of @main and the asset file names. Asset file -// names will be used to feed the corresponding tensors during initialization -// upon model loading. -absl::StatusOr> RunExportPasses( - const ExportOptions& export_opts, MLIRContext& ctx, ModuleOp module_op) { - if (export_opts.unfreeze_constants) { - TF_RETURN_IF_ERROR(UnfreezeConstantsAndSaveVariables( - export_opts.checkpoint_dir, ctx, module_op)); - LOG(INFO) << "Unfrozen constants and saved variables to checkpoint file: " - << export_opts.checkpoint_dir; - } - - if (absl::Status pass_run_status = RunPasses( - /*name=*/ - export_opts.debug_name, - /*add_passes_func=*/ - [dup_constants = export_opts.duplicate_shape_determining_constants]( - PassManager& pm) { AddExportPasses(pm, dup_constants); }, - ctx, module_op); - !pass_run_status.ok()) { - return pass_run_status; - } - - FailureOr> asset_file_defs = - quant::ConvertAssetArgs(module_op); - if (failed(asset_file_defs)) { - return absl::InternalError("Failed to convert asset args."); - } - - return *asset_file_defs; -} - -} // namespace CalibrationComponent::CalibrationComponent( absl::Nonnull ctx, @@ -171,6 +80,13 @@ absl::StatusOr CalibrationComponent::ExportToSavedModel( ModuleOp module_op, const absl::string_view dst_saved_model_path) { TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); + // Clone ModuleOp and function aliases so changes in this pipeline won't + // be reflected in the original values. + mlir::OwningOpRef cloned_module_ref(module_op.clone()); + + // Disable DumpTensor ops when running calibration. + DisableDebugging(*cloned_module_ref); + // `duplicate_shape_determining_constants = false` because the // resulting graph of this step is not expected to be loaded on TPU. const ExportOptions export_opts = { @@ -179,11 +95,11 @@ absl::StatusOr CalibrationComponent::ExportToSavedModel( /*debug_name=*/absl::StrCat(kName, kExportStepSuffix)}; TF_ASSIGN_OR_RETURN(const SmallVector asset_file_defs, - RunExportPasses(export_opts, *ctx_, module_op)); + RunExportPasses(export_opts, *ctx_, *cloned_module_ref)); TF_ASSIGN_OR_RETURN(ExportedModel exported_model, ConvertMlirModuleToExportedModel( - module_op, checkpoint_dir, function_aliases_, + *cloned_module_ref, checkpoint_dir, function_aliases_, {asset_file_defs.begin(), asset_file_defs.end()})); py_function_lib_->SaveExportedModel(dst_saved_model_path, exported_model, @@ -193,35 +109,6 @@ absl::StatusOr CalibrationComponent::ExportToSavedModel( return exported_model; } -absl::StatusOr CalibrationComponent::ImportCalibratedSavedModel( - const absl::string_view calibrated_saved_model_path) { - // Convert the SavedModelBundle to an MLIR module. - TF_ASSIGN_OR_RETURN(ImportedMlirModuleOp imported_module, - SavedModelToMlirModuleOp(calibrated_saved_model_path, - tags_, signature_keys_, *ctx_)); - ModuleOp module_op = imported_module.first; - - UpdateFunctionAliases(function_aliases_, module_op); - - // Collect the names of the functions that have aliases so that they may not - // be inlined. - absl::flat_hash_set aliased_function_names; - absl::c_for_each(function_aliases_, [&](const auto& aliases) { - return aliased_function_names.insert(aliases.first); - }); - - // Freezing is required again since variables might have been produced - // during the pre-calibration step. `is_inliner_run = false` to prevent the - // functions lifted for quantization from being inlined. - TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph( - /*mlir_dump_file_prefix=*/kName, /*is_inliner_run=*/false, - /*noinline_functions=*/aliased_function_names, module_op, ctx_, - imported_module.second == nullptr ? nullptr - : imported_module.second->GetSession(), - /*run_tf_to_stablehlo=*/false, /*deserialize_xla_call_module=*/true)); - return module_op; -} - absl::StatusOr CalibrationComponent::Run( ModuleOp module_op, const QuantizationConfig& config) { // Exports the pre-calibrated model to SavedModel. @@ -251,23 +138,14 @@ absl::StatusOr CalibrationComponent::Run( /*force_graph_mode_calibration=*/true, representative_dataset_file_map); if (absl::Status status = AddCalibrationStatistics( - *exported_model.mutable_graph_def(), config.calibration_options(), - *py_function_lib_); + module_op, config.calibration_options(), *py_function_lib_); !status.ok()) { LOG(WARNING) << "Some CustomAggregator ops do not have min or max " "values. Parts of the graph are not quantized. " << status; } - // Exports the calibrated model with statistics attached to the graph. - TF_ASSIGN_OR_RETURN(const std::string calibrated_saved_model_path, - CreateTmpDir()); - py_function_lib_->SaveExportedModel(calibrated_saved_model_path, - exported_model, src_saved_model_path_, - tags_, signature_def_map_); - - // Imports the calibrated saved model back to `ModuleOp`. - return ImportCalibratedSavedModel(calibrated_saved_model_path); + return module_op; } } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc index 22160a8820dfcd..39f4ca8449ae05 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc @@ -19,21 +19,19 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/graph.pb.h" namespace stablehlo::quantization { namespace { using ::stablehlo::quantization::CalibrationOptions; -using ::tensorflow::GraphDef; -using ::tensorflow::NodeDef; using ::tensorflow::calibrator::CalibrationStatistics; using ::tensorflow::calibrator::CalibratorSingleton; using ::tensorflow::quantization::PyFunctionLibrary; @@ -41,13 +39,12 @@ using ::tensorflow::quantization::PyFunctionLibrary; } // namespace absl::Status AddCalibrationStatistics( - GraphDef& graph_def, const CalibrationOptions& calibration_options, + mlir::ModuleOp module_op, const CalibrationOptions& calibration_options, const PyFunctionLibrary& py_function_library) { absl::Status status = absl::OkStatus(); - MutateNodeDefs(graph_def, [&py_function_library, &calibration_options, - &status](NodeDef& node_def) { - if (node_def.op() != "CustomAggregator") return; - const std::string& id = node_def.attr().at("id").s(); + module_op.walk([&py_function_library, &calibration_options, + &status](mlir::TF::CustomAggregatorOp aggregator_op) { + mlir::StringRef id = aggregator_op.getId(); std::optional statistics = CalibratorSingleton::GetStatistics(id); if (statistics == std::nullopt) { @@ -63,8 +60,9 @@ absl::Status AddCalibrationStatistics( calibration_options); CalibratorSingleton::ClearData(id); - (*node_def.mutable_attr())["min"].set_f(min_value); - (*node_def.mutable_attr())["max"].set_f(max_value); + mlir::OpBuilder builder(aggregator_op); + aggregator_op->setAttr("min", builder.getF32FloatAttr(min_value)); + aggregator_op->setAttr("max", builder.getF32FloatAttr(max_value)); }); return status; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h index 0069692381b6d5..9b67f22a2dac72 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h @@ -16,10 +16,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ #include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/core/framework/graph.pb.h" namespace stablehlo::quantization { @@ -28,7 +27,7 @@ namespace stablehlo::quantization { // respectively. `calibration_options` provides the strategy to retrieve min and // max values. absl::Status AddCalibrationStatistics( - tensorflow::GraphDef& graph_def, + mlir::ModuleOp module_op, const stablehlo::quantization::CalibrationOptions& calibration_options, const tensorflow::quantization::PyFunctionLibrary& py_function_library); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc index 679e1f8754be9b..0f9932d053cb4d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc @@ -14,12 +14,190 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" +#include + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + namespace stablehlo::quantization { +namespace { + +// Populate `CalibrationOptions` with default fields. +void PopulateDefaultCalibrationOptions(QuantizationConfig& quant_config) { + if (!quant_config.has_calibration_options() || + quant_config.calibration_options().calibration_method() == + CalibrationOptions::CALIBRATION_METHOD_UNSPECIFIED) { + quant_config.mutable_calibration_options()->set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); + } + switch (quant_config.calibration_options().calibration_method()) { + case CalibrationOptions::CALIBRATION_METHOD_MIN_MAX: + break; + case CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX: + break; + case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE: + if (quant_config.calibration_options() + .calibration_parameters() + .initial_num_bins() == 0) { + quant_config.mutable_calibration_options() + ->mutable_calibration_parameters() + ->set_initial_num_bins(256); + } + if (quant_config.calibration_options() + .calibration_parameters() + .min_percentile() == 0) { + quant_config.mutable_calibration_options() + ->mutable_calibration_parameters() + ->set_min_percentile(0.001); + } + if (quant_config.calibration_options() + .calibration_parameters() + .max_percentile() == 0) { + quant_config.mutable_calibration_options() + ->mutable_calibration_parameters() + ->set_max_percentile(99.999); + } + break; + case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE: + if (quant_config.calibration_options() + .calibration_parameters() + .initial_num_bins() == 0) { + quant_config.mutable_calibration_options() + ->mutable_calibration_parameters() + ->set_initial_num_bins(256); + } + break; + case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY: + if (quant_config.calibration_options() + .calibration_parameters() + .initial_num_bins() == 0) { + quant_config.mutable_calibration_options() + ->mutable_calibration_parameters() + ->set_initial_num_bins(256); + } + break; + case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC: + if (quant_config.calibration_options() + .calibration_parameters() + .initial_num_bins() == 0) { + quant_config.mutable_calibration_options() + ->mutable_calibration_parameters() + ->set_initial_num_bins(256); + } + break; + default: + break; + } +} + +// Returns a default `QuantizationSpec` for performing static-range PTQ on all +// ops. +// +// In textproto, the spec corresponds to: +// +// { +// {matcher {function_name {regex: ".*"}} +// {method {static_range_ptq {}}} +// } +QuantizationSpec GetDefaultStaticRangePtqSpec(StaticRangePtqPreset preset) { + QuantizationSpec spec{}; + // Default for all ops. + spec.mutable_matcher()->mutable_function_name()->set_regex( + preset.enable_full_int_quantization() ? ".*" : "^.*(conv|dot|gather).*"); + spec.mutable_method()->mutable_static_range_ptq(); + + return spec; +} + +// Returns a `QuantizationSpec` for performing static-range PTQ on the +// convolution quantizable unit family. Enables per-channel quantization for +// weights, on the channel dimension. +// +// In textproto, the spec corresponds to: +// +// { +// {matcher {function_name {regex: "composite_conv.*"}}} +// {method {static_range_ptq +// {input_quantized_types { +// key: 1, +// value {dimension_specs {dimension: 3}}}} +// }} +// } +QuantizationSpec GetStaticRangePtqSpecForConvolution() { + QuantizationSpec spec{}; + + // Matches all convolution quantizable unit family. + spec.mutable_matcher()->mutable_function_name()->set_regex( + "composite_conv.*"); + StaticRangePtq& static_range_ptq_spec = + *spec.mutable_method()->mutable_static_range_ptq(); + + // Enable per-channel quantization for convolution weights. + QuantizedType conv_weight_quantized_type{}; + + // Assumes NHWC format, specifying the channel dimension (3) as the + // quantized axis. + conv_weight_quantized_type.mutable_dimension_specs()->set_dimension(3); + + // The index of weight operands passed to lifted functions for convolution + // is 1. + static_range_ptq_spec.mutable_input_quantized_types()->try_emplace( + 1, std::move(conv_weight_quantized_type)); + + return spec; +}; + +void ExpandStaticRangePtqPreset(const StaticRangePtqPreset& preset, + QuantizationConfig& config) { + // Populate with preset's representative dataset configs if the user didn't + // explicitly specify other representative dataset configs to the top-level + // `CalibrationOptions`. + if (config.calibration_options().representative_datasets().empty()) { + auto preset_datasets = preset.representative_datasets(); + config.mutable_calibration_options() + ->mutable_representative_datasets() + ->Add(preset_datasets.begin(), preset_datasets.end()); + } + + // Create a new `QuantizationSpecs` to replace the existing one. The + // expansion from `StaticRangePtqPreset` gets populated first and then + // user-provided explicit `QuantizationSpec`s will be appended. + QuantizationSpecs new_specs{}; + *new_specs.add_specs() = + GetDefaultStaticRangePtqSpec(/*preset=*/config.static_range_ptq_preset()); + *new_specs.add_specs() = GetStaticRangePtqSpecForConvolution(); + + // Append user-provided specs to override existing specs. + const QuantizationSpecs& previous_specs = config.specs(); + new_specs.mutable_specs()->Add(previous_specs.specs().begin(), + previous_specs.specs().end()); + + config.mutable_specs()->Swap(&new_specs); +} + +} // namespace + +QuantizationConfig ExpandPresets(const QuantizationConfig& config) { + QuantizationConfig new_config = config; + + // Update the `new_config` with each preset's expansions. + switch (config.preset_case()) { + case QuantizationConfig::kStaticRangePtqPreset: + ExpandStaticRangePtqPreset(config.static_range_ptq_preset(), new_config); + break; + default: + // Preset has not been specified. The expansion is a no-op. + break; + } + + return new_config; +} QuantizationConfig PopulateDefaults( const QuantizationConfig& user_provided_config) { QuantizationConfig config = user_provided_config; + PopulateDefaultCalibrationOptions(config); + PipelineConfig& pipeline_config = *config.mutable_pipeline_config(); if (!pipeline_config.has_unpack_quantized_types()) { pipeline_config.set_unpack_quantized_types(true); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h index 20b9efa4a60fa0..5dc4554d784c92 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h @@ -24,6 +24,23 @@ namespace stablehlo::quantization { QuantizationConfig PopulateDefaults( const QuantizationConfig& user_provided_config); +// Returns a copy of `QuantizationConfig` where presets are expanded and +// transformed into other fields in `QuantizationConfig`. +// +// The expansion rules are as follows: +// * StaticRangePtqPreset +// - The preset's `representative_datasets` field will be transferred to +// `QuantizationConfig.calibration_options.representative_datasets`, unless +// the user explicitly provided representative dataset configs to +// `calibration_options`. In that case, the explicit configs take precedence +// and the preset's configs are ignored. +// - For `QuantizationSpecs`, the expanded `QuantizationSpec`s will be +// populated first and user-provided `QuantizationSpec`s, if any, will be +// appended. This expresses the fact that user-provided specs take precedence. +// * Preset unspecified +// - No-op. +QuantizationConfig ExpandPresets(const QuantizationConfig& config); + } // namespace stablehlo::quantization #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc index 5912788bddf96b..e3f2bfde3d10c3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc @@ -14,12 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" +#include #include #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" namespace stablehlo::quantization { namespace { +using ::testing::Eq; +using ::testing::SizeIs; +using ::testing::StrEq; + TEST(PopulateDefaultsTest, PopulateDefaultsForEmptyConfig) { QuantizationConfig config{}; @@ -37,5 +42,237 @@ TEST(PopulateDefaultsTest, PopulateDefaultsForConfigWithUnpackQuantizedTypes) { EXPECT_FALSE(new_config.pipeline_config().unpack_quantized_types()); } +TEST(PopulateDefaultsTest, DefaultCalibrationOptionsPopulated) { + QuantizationConfig config{}; + + const QuantizationConfig new_config = PopulateDefaults(config); + EXPECT_THAT(new_config.calibration_options().calibration_method(), + Eq(CalibrationOptions::CALIBRATION_METHOD_MIN_MAX)); +} + +TEST(PopulateDefaultsTest, + DefaultCalibrationOptionsPopulatedForUnspecifiedMethod) { + QuantizationConfig config{}; + CalibrationOptions& calibration_options = + *config.mutable_calibration_options(); + calibration_options.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_UNSPECIFIED); + + const QuantizationConfig new_config = PopulateDefaults(config); + EXPECT_THAT(new_config.calibration_options().calibration_method(), + Eq(CalibrationOptions::CALIBRATION_METHOD_MIN_MAX)); +} + +TEST(PopulateDefaultsTest, ExplicitCalibrationOptionsNotOverridden) { + QuantizationConfig config{}; + CalibrationOptions& calibration_options = + *config.mutable_calibration_options(); + calibration_options.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX); + calibration_options.mutable_calibration_parameters()->set_initial_num_bins( + 512); + + // Test that if the user explicitly provided `calibration_options`, it is not + // overridden. + const QuantizationConfig new_config = PopulateDefaults(config); + EXPECT_THAT(new_config.calibration_options().calibration_method(), + Eq(CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX)); + EXPECT_THAT(new_config.calibration_options() + .calibration_parameters() + .initial_num_bins(), + Eq(512)); +} + +TEST(PopulateDefaultsTest, DefaultNumbersPopulatedForPartOfCalibrationOptions) { + QuantizationConfig config{}; + CalibrationOptions& calibration_options = + *config.mutable_calibration_options(); + calibration_options.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE); + calibration_options.mutable_calibration_parameters()->set_initial_num_bins( + 512); + + // Test that if the user explicitly provided part of the + // `calibration_options`, it is not overridden, rest of the data are default. + const QuantizationConfig new_config = PopulateDefaults(config); + EXPECT_THAT(new_config.calibration_options().calibration_method(), + Eq(CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE)); + EXPECT_THAT(new_config.calibration_options() + .calibration_parameters() + .initial_num_bins(), + Eq(512)); + EXPECT_THAT(new_config.calibration_options() + .calibration_parameters() + .min_percentile(), + Eq(0.001f)); + EXPECT_THAT(new_config.calibration_options() + .calibration_parameters() + .max_percentile(), + Eq(99.999f)); +} + +TEST(PopulateDefaultsTest, + DefaultNumbersPopulatedForCalibrationOptionsOfHistogramMseBruteforce) { + QuantizationConfig config{}; + CalibrationOptions& calibration_options = + *config.mutable_calibration_options(); + calibration_options.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE); + + const QuantizationConfig new_config = PopulateDefaults(config); + EXPECT_THAT( + new_config.calibration_options().calibration_method(), + Eq(CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE)); + EXPECT_THAT(new_config.calibration_options() + .calibration_parameters() + .initial_num_bins(), + Eq(256)); + EXPECT_THAT(new_config.calibration_options() + .calibration_parameters() + .min_percentile(), + Eq(0.0f)); + EXPECT_THAT(new_config.calibration_options() + .calibration_parameters() + .max_percentile(), + Eq(0.0f)); +} + +TEST(ExpandPresetsTest, ExpandUnspecifiedPreset) { + QuantizationConfig config{}; + const QuantizationConfig new_config = ExpandPresets(config); + + // Test that nothing has been changed. + EXPECT_FALSE(new_config.has_specs()); + EXPECT_FALSE(new_config.has_calibration_options()); + EXPECT_FALSE(new_config.has_pipeline_config()); +} + +TEST(ExpandPresetsTest, ExpandStaticRangePtqEnableFullIntquantization) { + QuantizationConfig config{}; + RepresentativeDatasetConfig& preset_dataset_config = + *config.mutable_static_range_ptq_preset()->add_representative_datasets(); + config.mutable_static_range_ptq_preset()->set_enable_full_int_quantization( + true); + preset_dataset_config.mutable_tf_record()->set_path("/test/path"); + + const QuantizationConfig new_config = ExpandPresets(config); + ASSERT_THAT(new_config.specs().specs(), SizeIs(2)); + + const QuantizationSpec& default_spec = new_config.specs().specs(0); + EXPECT_THAT(default_spec.matcher().function_name().regex(), StrEq(".*")); + EXPECT_TRUE(default_spec.method().has_static_range_ptq()); + + // Test that the expansion for convolution ops is done. + const QuantizationSpec& conv_spec = new_config.specs().specs(1); + EXPECT_THAT(conv_spec.matcher().function_name().regex(), + StrEq("composite_conv.*")); + ASSERT_TRUE(conv_spec.method().has_static_range_ptq()); + + const StaticRangePtq& srq_spec = conv_spec.method().static_range_ptq(); + ASSERT_THAT(srq_spec.input_quantized_types(), SizeIs(1)); + ASSERT_TRUE(srq_spec.input_quantized_types().contains(1)); + + EXPECT_THAT( + srq_spec.input_quantized_types().at(1).dimension_specs().dimension(), + Eq(3)); + + // Test that representative dataset config has been transferred to the + // `CalibrationOptions`. + ASSERT_THAT(new_config.calibration_options().representative_datasets(), + SizeIs(1)); + EXPECT_THAT(new_config.calibration_options() + .representative_datasets(0) + .tf_record() + .path(), + StrEq("/test/path")); +} + +TEST(ExpandPresetsTest, ExpandStaticRangePtqPresetDefault) { + QuantizationConfig config{}; + RepresentativeDatasetConfig& preset_dataset_config = + *config.mutable_static_range_ptq_preset()->add_representative_datasets(); + preset_dataset_config.mutable_tf_record()->set_path("/test/path"); + + const QuantizationConfig new_config = ExpandPresets(config); + ASSERT_THAT(new_config.specs().specs(), SizeIs(2)); + + const QuantizationSpec& spec = new_config.specs().specs(0); + EXPECT_THAT(spec.matcher().function_name().regex(), + StrEq("^.*(conv|dot|gather).*")); + EXPECT_TRUE(spec.method().has_static_range_ptq()); +} + +TEST(ExpandPresetsTest, + ExpandStaticRangePtqPresetWithTopLevelRepresentativeDataset) { + // Test the scenario where both + // `config.calibration_options.representative_datasets` and + // `config.static_range_ptq_preset.representative_datasets` are both + // specified. In this case, the one set to the `calibration_options` takes + // precedence. + QuantizationConfig config{}; + RepresentativeDatasetConfig& top_level_dataset_config = + *config.mutable_calibration_options()->add_representative_datasets(); + top_level_dataset_config.mutable_tf_record()->set_path("/test/path/1"); + + RepresentativeDatasetConfig& preset_dataset_config = + *config.mutable_static_range_ptq_preset()->add_representative_datasets(); + preset_dataset_config.mutable_tf_record()->set_path("/test/path/2"); + + const QuantizationConfig new_config = ExpandPresets(config); + + // Test that representative dataset config has not been transferred to the + // `CalibrationOptions`. Top-level config takes precedence. + ASSERT_THAT(new_config.calibration_options().representative_datasets(), + SizeIs(1)); + EXPECT_THAT(new_config.calibration_options() + .representative_datasets(0) + .tf_record() + .path(), + StrEq("/test/path/1")); +} + +TEST(ExpandPresetsTest, ExpandStaticRangePtqPresetThenAppendExplicitSpecs) { + QuantizationConfig config{}; + config.mutable_static_range_ptq_preset()->set_enable_full_int_quantization( + true); + + QuantizationSpec& user_provided_spec = *config.mutable_specs()->add_specs(); + user_provided_spec.mutable_matcher()->mutable_function_name()->set_regex( + "composite_dot_general_fn_1"); + user_provided_spec.mutable_method()->mutable_no_quantization(); + + // Test that the expanded `QuantizationSpec`s are populated first and then + // user-provided specs are appended. + // + // It should look like: + // + // specs {matcher {function_name {regex: ".*"}} method {static_range_ptq {}}} + // specs { + // matcher {function_name {regex: "composite_conv.*"}} + // method {static_range_ptq {...}}} + // } + // specs { + // matcher {function_name {regex: "composite_dot_general_fn_1"}} + // method {no_quantization {}} + // } + const QuantizationConfig new_config = ExpandPresets(config); + ASSERT_THAT(new_config.specs().specs(), SizeIs(3)); + + const QuantizationSpec& first_spec = new_config.specs().specs(0); + EXPECT_THAT(first_spec.matcher().function_name().regex(), StrEq(".*")); + EXPECT_TRUE(first_spec.method().has_static_range_ptq()); + + const QuantizationSpec& second_spec = new_config.specs().specs(1); + EXPECT_THAT(second_spec.matcher().function_name().regex(), + StrEq("composite_conv.*")); + EXPECT_TRUE(second_spec.method().has_static_range_ptq()); + + // This corresponds to `user_provided_spec`. + const QuantizationSpec& third_spec = new_config.specs().specs(2); + EXPECT_THAT(third_spec.matcher().function_name().regex(), + StrEq("composite_dot_general_fn_1")); + EXPECT_TRUE(third_spec.method().has_no_quantization()); +} + } // namespace } // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc index 1ba51790de0ac9..a06c7f8ed79fb4 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc @@ -14,61 +14,34 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/protobuf/meta_graph.pb.h" namespace stablehlo::quantization { -namespace { -using ::tensorflow::NodeDef; -using ::tensorflow::SignatureDef; -using ::tensorflow::quantization::DebuggerOptions; -using ::tensorflow::quantization::ExportedModel; -using ::tensorflow::quantization::PyFunctionLibrary; +void DisableDebugging(mlir::ModuleOp module_op) { + module_op.walk( + [](mlir::TF::DumpTensorOp dump_op) { dump_op.setEnabled(false); }); +} -} // namespace +void EnableDebugging(tensorflow::quantization::ExportedModel& exported_model) { + MutateNodeDefs(*exported_model.mutable_graph_def(), + [](tensorflow::NodeDef& node_def) { + if (node_def.op() == "DumpTensor") { + (*node_def.mutable_attr())["enabled"].set_b(true); + } + }); +} -void EnableDebugging( - ExportedModel& exported_model, const DebuggerOptions& debugger_options, - const PyFunctionLibrary& py_function_library, - const absl::string_view src_saved_model_path, - const std::unordered_set& tags, - const absl::flat_hash_map& signature_def_map) { - // Enable `DumpTensor` nodes in `graph_def`. DumpTensor is disabled by - // default to avoid logging data during calibration. - MutateNodeDefs(*exported_model.mutable_graph_def(), [](NodeDef& node_def) { - if (node_def.op() == "DumpTensor") { - (*node_def.mutable_attr())["enabled"].set_b(true); - } +void ChangeToQuantizedFilename(mlir::ModuleOp module_op) { + module_op.walk([](mlir::TF::DumpTensorOp dump_op) { + dump_op.setFileName("quantized_tensor_data.pb"); }); - - if (debugger_options.debugger_type() == - DebuggerConfig::DEBUGGER_TYPE_WHOLE_MODEL) { - // TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model. - // TODO: b/296916287 - Create a separate function for saving unquantized - // dump model. - py_function_library.SaveExportedModel( - debugger_options.unquantized_dump_model_path(), exported_model, - src_saved_model_path, tags, signature_def_map); - - // Update the `DumpTensor` ops' file name in `graph_def`. - MutateNodeDefs(*exported_model.mutable_graph_def(), [](NodeDef& node_def) { - if (node_def.op() == "DumpTensor") { - (*node_def.mutable_attr())["file_name"].set_s( - "quantized_tensor_data.pb"); - } - }); - } } } // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h index 6bb427ecbdf1fd..f034e4d94ee4bf 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h @@ -15,35 +15,20 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/core/protobuf/meta_graph.pb.h" namespace stablehlo::quantization { -// Enables debugging on `exported_model` by updating the `DumpTensor` ops. -// -// Saves the current model to `debugger_options.unquantized_dump_model_path()` -// if the debugger type is `DEBUGGER_TYPE_WHOLE_MODEL`. This is required because -// in whole-model debugging mode the `DumpTensor` ops for the unquantized -// tensors are only inserted in the unquantized model whereas `DumpTensor` ops -// for the quantized tensors are only inserted in the quantized model. Both -// models are required to be able to dump both quantized and unquantized tensors -// and compare them offline. -void EnableDebugging( - tensorflow::quantization::ExportedModel& exported_model, - const tensorflow::quantization::DebuggerOptions& debugger_options, - const tensorflow::quantization::PyFunctionLibrary& py_function_library, - absl::string_view src_saved_model_path, - const std::unordered_set& tags, - const absl::flat_hash_map& - signature_def_map); +// Disables debugging on `DumpTensor` ops. +void DisableDebugging(mlir::ModuleOp module_op); + +// Enables debugging on `DumpTensor` ops. +void EnableDebugging(tensorflow::quantization::ExportedModel& exported_model); + +// Changes the filename from `unquantized_tensor_data.pb` to +// `quantized_tensor_data.pb`. +void ChangeToQuantizedFilename(mlir::ModuleOp module_op); } // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index 59e64d6d77d95e..ebe950c58142f6 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" @@ -37,14 +38,11 @@ void AddPreCalibrationPasses(OpPassManager& pm, const CalibrationOptions& calibration_options, const QuantizationSpecs& quantization_specs, const DebuggerConfig& debugger_config) { - // For models with NCHW convolution format. This pass is required because - // downstream pipeline handles NHWC convolution better for most cases. - pm.addNestedPass(createNchwConvolutionToNhwcPass()); + // Convert NCHW tensors to NHWC at along with extra optimizations as + // downstream passes perform better optimizations when dealing with NHWC + // formatted tensors. + AddProcessNchwTensorPasses(pm); - // Folds `stablehlo.constant`->`stablehlo.transpose` patterns, which is often - // generated as by-products after optimizing dimension numbers (e.g. - // NCHW->NHWC convolution conversion). - pm.addNestedPass(createFoldConstantTransposePass()); pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass(quantization_specs)); if (debugger_config.debugger_type() != DebuggerConfig::DEBUGGER_TYPE_UNSPECIFIED) { @@ -60,11 +58,16 @@ void AddPostCalibrationPasses( OpPassManager& pm, const PipelineConfig& pipeline_config, const StaticRangePtqPreset& static_range_ptq_preset) { QuantizeCompositeFunctionsPassOptions options; + // TODO: b/331120943 - Use QuantizationConfig instead of preset flags. options.enable_per_channel_quantized_weight_ = static_range_ptq_preset.enable_per_channel_quantized_weight(); + options.enable_full_int_quantization_ = + static_range_ptq_preset.enable_full_int_quantization(); // For debugging purposes. options.mlir_dump_file_name_ = "quantize_composite_functions"; options.enable_weight_only_ = false; + + AddShapeLegalizationPasses(pm); pm.addNestedPass( CreateConvertCustomAggregationOpToQuantStatsPass()); pm.addPass(createQuantizeCompositeFunctionsPass(options)); @@ -75,6 +78,38 @@ void AddPostCalibrationPasses( } } +void AddWeightOnlyQuantizationPasses( + OpPassManager& pm, const QuantizationSpecs& quantization_specs, + const PipelineConfig& pipeline_config, + const DebuggerConfig& debugger_config) { + // For models with NCHW convolution format. This pass is required because + // downstream pipeline handles NHWC convolution better for most cases. + pm.addNestedPass(createNchwConvolutionToNhwcPass()); + + // Folds `stablehlo.constant`->`stablehlo.transpose` patterns, which is often + // generated as by-products after optimizing dimension numbers (e.g. + // NCHW->NHWC convolution conversion). + pm.addNestedPass(createFoldConstantTransposePass()); + pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass(quantization_specs)); + if (debugger_config.debugger_type() != + DebuggerConfig::DEBUGGER_TYPE_UNSPECIFIED) { + pm.addPass(CreateAddDumpTensorOpPass(debugger_config.debugger_type(), + debugger_config.log_dir_path())); + } + AddShapeLegalizationPasses(pm); + QuantizeCompositeFunctionsPassOptions options; + // For debugging purposes. + options.mlir_dump_file_name_ = "quantize_composite_functions"; + options.enable_weight_only_ = true; + pm.addPass(createQuantizeCompositeFunctionsPass(options)); + + // Add an inliner pass to inline quantized StableHLO functions. + pm.addPass(createInlinerPass()); + if (pipeline_config.unpack_quantized_types()) { + AddStablehloQuantToIntPasses(pm); + } +} + void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm) { pm.addPass(TF::CreateXlaCallModuleDeserializationPass()); pm.addPass(createRestoreFunctionNamePass()); @@ -119,4 +154,26 @@ void AddCallModuleSerializationPasses(OpPassManager& pm) { pm.addPass(TF::CreateXlaCallModuleSerializationPass()); } +void AddProcessNchwTensorPasses(OpPassManager& pm) { + // For models with NCHW convolution format. This pass is required because + // downstream pipeline handles NHWC convolution better for most cases. + pm.addNestedPass(createNchwConvolutionToNhwcPass()); + + // Recursively push down the `stablehlo.transpose` ops for activations + // generated by the `NchwConvolutionToNhwc` pass. + pm.addNestedPass(createDeferActivationTransposePass()); + + // Folds `stablehlo.constant`->`stablehlo.transpose` patterns, which is often + // generated as by-products after optimizing dimension numbers (e.g. + // NCHW->NHWC convolution conversion). + pm.addNestedPass(createFoldConstantTransposePass()); +} + +void RegisterPassPipelines() { + static PassPipelineRegistration<> nchw_tensor_format_processing_pipeline( + /*arg=*/"stablehlo-process-nchw-tensor", + /*description=*/"Optimizes tensors with NCHW format.", + AddProcessNchwTensorPasses); +} + } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h index ef7b51aaf6096f..4f94506b6c184e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h @@ -37,6 +37,13 @@ void AddPostCalibrationPasses( const ::stablehlo::quantization::StaticRangePtqPreset& static_range_ptq_preset); +// Adds passes for weight-only quantization. +void AddWeightOnlyQuantizationPasses( + OpPassManager& pm, + const ::stablehlo::quantization::QuantizationSpecs& quantization_specs, + const ::stablehlo::quantization::PipelineConfig& pipeline_config, + const ::stablehlo::quantization::DebuggerConfig& debugger_config); + // Deserializes StableHLO functions serialized and embedded in XlaCallModuleOps. void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm); @@ -54,6 +61,16 @@ void AddCallModuleSerializationPasses(OpPassManager& pm); // through a StableHLO <-> MHLO roundtrip to utilize the MHLOQuantToInt pass. void AddStablehloQuantToIntPasses(OpPassManager& pm); +// Processes tensors with NCHW format (== (batch, channel, height, weight)) by +// converting them to NHWC formats along with extra optimizations such as +// constant folding the transpose->convolution pattern. This is useful when +// downstream pipeline (e.g. XLA) is more optimized when accepting NHWC formats. +void AddProcessNchwTensorPasses(OpPassManager& pm); + +// Registers quantization pass pipelines. This is only required when running +// MLIR opt binaries and not required when adding passes programmatically. +void RegisterPassPipelines(); + } // namespace mlir::quant::stablehlo #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PASS_PIPELINE_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h new file mode 100644 index 00000000000000..35b1082b10dae9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PERMUTATION_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PERMUTATION_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" // IWYU pragma: keep; required to include the definition of ArrayRef +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" // IWYU pragma: keep; required to include the definition of SmallVector +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir::quant { + +// Permutes `values` with `permutation`. Returns the permuted values. Sizes of +// `values` and `permutation` must be equal, and the elements of `permutation` +// should be less than `values.size()`. +template , void>> +SmallVector Permute(const ArrayRef values, + const ArrayRef permutation) { + SmallVector permuted_values(/*Size=*/values.size(), /*Value=*/T{}); + for (auto [i, permutation_idx] : llvm::enumerate(permutation)) { + permuted_values[i] = std::move(values[permutation_idx]); + } + return permuted_values; +} + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PERMUTATION_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation_test.cc new file mode 100644 index 00000000000000..27a7886ba38466 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" + +#include +#include + +#include +#include +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir::quant { +namespace { + +using testing::ElementsAre; +using testing::IsEmpty; + +TEST(PermutationTest, PermuteEmptyArray) { + const SmallVector permutation_result = + Permute(SmallVector{}, SmallVector{}); + EXPECT_THAT(permutation_result, IsEmpty()); +} + +TEST(PermutationTest, PermuteOneElement) { + const SmallVector single_element_array = {8}; + const SmallVector permutation = {0}; + + const SmallVector permutation_result = + Permute(single_element_array, permutation); + EXPECT_THAT(permutation_result, ElementsAre(8)); +} + +TEST(PermutationTest, PermuteFourElements) { + const SmallVector arr = {0, 3, 1, 2}; + // Permutation inverse of {0, 3, 1, 2}. + const SmallVector permutation = {0, 2, 3, 1}; + + const SmallVector permutation_result = Permute(arr, permutation); + EXPECT_THAT(permutation_result, ElementsAre(0, 1, 2, 3)); +} + +TEST(PermutationTest, PermuteFourStringElements) { + const SmallVector arr = {"a", "b", "c", "d"}; + const SmallVector permutation = {0, 2, 3, 1}; + + const SmallVector permutation_result = + Permute(arr, permutation); + EXPECT_THAT(permutation_result, ElementsAre("a", "c", "d", "b")); +} + +} // namespace +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc index c17c39d8783ba8..3d4d2295455a5c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" @@ -34,6 +35,8 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { +using ::stablehlo::quantization::ExpandPresets; +using ::stablehlo::quantization::PopulateDefaults; using ::stablehlo::quantization::QuantizationConfig; using ::testing::Contains; using ::testing::SizeIs; @@ -92,8 +95,11 @@ TEST_F(PreCalibrationComponentTest, )mlir"); ASSERT_TRUE(module_op); + QuantizationConfig quantization_config{}; + quantization_config.mutable_static_range_ptq_preset(); + quantization_config = ExpandPresets(PopulateDefaults(quantization_config)); absl::StatusOr pre_calibration_result = - component.Run(*module_op, QuantizationConfig()); + component.Run(*module_op, quantization_config); EXPECT_THAT(pre_calibration_result, IsOk()); diff --git a/third_party/xla/xla/python/xla_extension.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc similarity index 54% rename from third_party/xla/xla/python/xla_extension.cc rename to tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc index 5adc194d65f054..ef24c16dbf4acc 100644 --- a/third_party/xla/xla/python/xla_extension.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,10 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h" -#include "pybind11/pybind11.h" // from @pybind11 -#include "xla/python/xla.h" +#include -extern "C" PYBIND11_EXPORT PyObject *PyInit_xla_extension() { - return xla::InitializeXlaExtension(); +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { + +using ::stablehlo::quantization::QuantizationResult; + +void QuantizationReport::AddQuantizationResult(QuantizationResult&& result) { + *quantization_results_.add_results() = std::move(result); } + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h new file mode 100644 index 00000000000000..94eb47463f16c1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { + +// A class that manages information about `QuantizableUnit`s post-quantization, +// internally in the form of `QuantizationUnits`. It is used to collect +// quantization summary from a quantized `ModuleOp` and emit it in a human- and +// machine-readable format. +class QuantizationReport { + public: + QuantizationReport() = default; + + // Adds a `QuantizationResult` to the report. + void AddQuantizationResult( + ::stablehlo::quantization::QuantizationResult&& result); + + // Returns `QuantizationResults` that are registered in this report. + const ::stablehlo::quantization::QuantizationResults& GetQuantizationResults() + const { + return quantization_results_; + } + + private: + // Quantization results that are registered in this report. A quantization + // result may be added manually by calling `AddQuantizationResult`. + ::stablehlo::quantization::QuantizationResults quantization_results_; +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc new file mode 100644 index 00000000000000..f6897f7fde401d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h" + +#include + +#include +#include +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { +namespace { + +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizableUnit; +using ::stablehlo::quantization::QuantizationResult; +using ::stablehlo::quantization::QuantizationResults; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::StrEq; + +TEST(QuantizationReportTest, GetQuantizationResultsReturnsEmptyResults) { + QuantizationReport report{}; + + const QuantizationResults& results = report.GetQuantizationResults(); + ASSERT_THAT(results.results(), IsEmpty()); +} + +TEST(QuantizationReportTest, AddQuantizationResult) { + // Construct a `QuantizationResult` to add, representing a unit named + // `quantized_my_function` that is not quantized. + QuantizationResult result{}; + QuantizableUnit& quantizable_unit = *result.mutable_quantizable_unit(); + quantizable_unit.set_name("quantized_my_function"); + + Method& method = *result.mutable_method(); + method.mutable_no_quantization(); + + QuantizationReport report{}; + report.AddQuantizationResult(std::move(result)); + + const QuantizationResults& results = report.GetQuantizationResults(); + ASSERT_THAT(results.results(), SizeIs(1)); + + const QuantizationResult& first_result = results.results(0); + EXPECT_THAT(first_result.quantizable_unit().name(), + StrEq("quantized_my_function")); + EXPECT_TRUE(first_result.method().has_no_quantization()); +} + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc index 7945ddf712209a..fd85bceca6f9c2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc @@ -17,10 +17,12 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/algorithm/container.h" +#include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" @@ -32,11 +34,17 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" @@ -56,6 +64,8 @@ namespace { using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; using ::mlir::tf_saved_model::kTfSavedModelInitializerInitType; using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; +using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::io::GetLocalTmpFileName; using ::tensorflow::AssetFileDef; using ::tensorflow::ConvertMlirToGraph; using ::tensorflow::FunctionDefLibrary; @@ -67,6 +77,8 @@ using ::tensorflow::NodeDef; using ::tensorflow::OpRegistry; using ::tensorflow::SaverDef; using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::RunPasses; +using ::tensorflow::quantization::UnfreezeConstantsAndSaveVariables; // Finds and returns the name of the node from a set of control output nodes. // The name should contain the string `contains`. Returns an empty string if no @@ -114,7 +126,29 @@ std::string FindFilePrefixTensorName(const GraphDef& graph_def) { } // namespace -ExportedModel CreateExportedModel( +absl::StatusOr CreateExportedModel( + const std::vector& signature_keys, + const std::unordered_set& tags, + const QuantizationConfig& quantization_config, + absl::string_view debug_name_prefix, + const absl::flat_hash_map& function_aliases, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND, ModuleOp module_op) { + TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); + const ExportOptions export_opts = { + /*duplicate_shape_determining_constants=*/true, + /*unfreeze_constants=*/false, checkpoint_dir, + /*debug_name=*/ + absl::StrCat(debug_name_prefix, kExportStepSuffix)}; + + TF_ASSIGN_OR_RETURN(const SmallVector asset_file_defs, + RunExportPasses(export_opts, ctx, module_op)); + + return ConvertMlirModuleToExportedModel( + module_op, checkpoint_dir, function_aliases, + {asset_file_defs.begin(), asset_file_defs.end()}); +} + +ExportedModel CreateExportedModelFromGraphDef( GraphDef&& graph_def, const absl::string_view init_node_name, const absl::string_view checkpoint_dir, const std::optional saver_def, @@ -222,9 +256,35 @@ absl::StatusOr ConvertMlirModuleToExportedModel( TF_ASSIGN_OR_RETURN(const std::optional saver_def, CreateSaverDef(control_ret_node_names, graph_def)); - return CreateExportedModel(std::move(graph_def), init_node_name, - checkpoint_dir, std::move(saver_def), - function_aliases, asset_file_defs); + return CreateExportedModelFromGraphDef(std::move(graph_def), init_node_name, + checkpoint_dir, std::move(saver_def), + function_aliases, asset_file_defs); +} + +absl::StatusOr> RunExportPasses( + const ExportOptions& export_opts, MLIRContext& ctx, ModuleOp module_op) { + if (export_opts.unfreeze_constants) { + TF_RETURN_IF_ERROR(UnfreezeConstantsAndSaveVariables( + export_opts.checkpoint_dir, ctx, module_op)); + LOG(INFO) << "Unfrozen constants and saved variables to checkpoint file: " + << export_opts.checkpoint_dir; + } + + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/ + export_opts.debug_name, + /*add_passes_func=*/ + [dup_constants = export_opts.duplicate_shape_determining_constants]( + PassManager& pm) { AddExportPasses(pm, dup_constants); }, + ctx, module_op)); + + FailureOr> asset_file_defs = + quant::ConvertAssetArgs(module_op); + if (failed(asset_file_defs)) { + return absl::InternalError("Failed to convert asset args."); + } + + return *asset_file_defs; } } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h index 1bfd0d5113f955..357c5b0efe52d7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h @@ -19,13 +19,18 @@ limitations under the License. #include #include +#include #include +#include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" @@ -56,8 +61,20 @@ struct ExportOptions { std::string debug_name = "stablehlo_quant"; }; +// Creates `ExportedModel` from `module_op`. `module_op` goes through post +// process passes before an `ExportModel` is created. +// TODO: b/329206105 - Add unit tests after decomposing post processing passes. +absl::StatusOr CreateExportedModel( + const std::vector& signature_keys, + const std::unordered_set& tags, + const ::stablehlo::quantization::QuantizationConfig& quantization_config, + absl::string_view debug_name_prefix, + const absl::flat_hash_map& function_aliases, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND, ModuleOp module_op); + // Factory function for `ExportedModel`. -[[nodiscard]] tensorflow::quantization::ExportedModel CreateExportedModel( +[[nodiscard]] tensorflow::quantization::ExportedModel +CreateExportedModelFromGraphDef( tensorflow::GraphDef&& graph_def, absl::string_view init_node_name, absl::string_view checkpoint_dir, std::optional saver_def, @@ -111,6 +128,15 @@ ConvertMlirModuleToExportedModel( const absl::flat_hash_map& function_aliases, const std::vector& asset_file_defs); +// Sets up and runs the passes for exporting `module_op`. The behavior of the +// exporting passes is controlled by `export_opts`. Returns `AssetFileDef`s that +// associate the input arguments of @main and the asset file names. Asset file +// names will be used to feed the corresponding tensors during initialization +// upon model loading. +// TODO: b/329206105 - Add unit tests after decomposing post processing passes. +absl::StatusOr> RunExportPasses( + const ExportOptions& export_opts, MLIRContext& ctx, ModuleOp module_op); + } // namespace mlir::quant::stablehlo #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export_test.cc index e250f5314726f7..7e55644c38f886 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export_test.cc @@ -57,10 +57,10 @@ TEST(CreateExportedModelTest, CreateExportedModelBasicFieldsSet) { ASSERT_TRUE( TextFormat::ParseFromString(R"pb(node { name: "foo" })pb", &graph_def)); - const ExportedModel exported_model = - CreateExportedModel(std::move(graph_def), "init_node_name", - "checkpoint_dir", /*saver_def=*/std::nullopt, - /*function_aliases=*/{}, /*asset_file_defs=*/{}); + const ExportedModel exported_model = CreateExportedModelFromGraphDef( + std::move(graph_def), "init_node_name", "checkpoint_dir", + /*saver_def=*/std::nullopt, + /*function_aliases=*/{}, /*asset_file_defs=*/{}); ASSERT_THAT(exported_model.graph_def().node(), SizeIs(1)); EXPECT_THAT(exported_model.graph_def().node()[0].name(), StrEq("foo")); @@ -72,7 +72,7 @@ TEST(CreateExportedModelTest, CreateExportedModelBasicFieldsSet) { } TEST(CreateExportedModelTest, CreateExportedModelWithAddedFunctionAliases) { - const ExportedModel exported_model = CreateExportedModel( + const ExportedModel exported_model = CreateExportedModelFromGraphDef( GraphDef(), /*init_node_name=*/"", /*checkpoint_dir=*/"", /*saver_def=*/std::nullopt, /*function_aliases=*/{{"func1", "alias1"}, {"func2", "alias2"}}, @@ -93,7 +93,7 @@ TEST(CreateExportedModelTest, CreateExportedModelWithAddedAssetFileDefs) { ASSERT_TRUE( TextFormat::ParseFromString(R"pb(filename: "fname2")pb", &asset2)); - const ExportedModel exported_model = CreateExportedModel( + const ExportedModel exported_model = CreateExportedModelFromGraphDef( GraphDef(), /*init_node_name=*/"", /*checkpoint_dir=*/"", /*saver_def=*/std::nullopt, /*function_aliases=*/{}, /*asset_file_defs=*/{asset1, asset2}); @@ -107,7 +107,7 @@ TEST(CreateExportedModelTest, CreateExportedModelWithAddedSaverDef) { ASSERT_TRUE(TextFormat::ParseFromString( R"pb(filename_tensor_name: "my_file")pb", &saver_def)); - const ExportedModel exported_model = CreateExportedModel( + const ExportedModel exported_model = CreateExportedModelFromGraphDef( GraphDef(), /*init_node_name=*/"", /*checkpoint_dir=*/"", saver_def, /*function_aliases=*/{}, /*asset_file_defs=*/{}); EXPECT_THAT(exported_model.saver_def().filename_tensor_name(), "my_file"); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc index 9c03ee6e21f4b5..a223a0b03f58a4 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc @@ -14,23 +14,72 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h" +#include #include #include +#include +#include +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace mlir::quant::stablehlo { +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::MLIRImportOptions; +using ::tensorflow::SavedModelBundle; +using ::tensorflow::SavedModelSignatureDefsToMlirImport; +using ::tensorflow::quantization::PreprocessAndFreezeGraph; + +absl::StatusOr SavedModelToMlirModuleOp( + const absl::string_view saved_model_path, + const std::unordered_set& tags, + const std::vector& signature_keys, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) { + MLIRImportOptions import_options; + import_options.upgrade_legacy = true; + import_options.lift_variables = false; + import_options.include_variables_in_initializers = true; + + auto bundle = std::make_unique(); + + // Copy to eliminate the `const` qualifier so that `absl::MakeSpan` can be + // called on it. + std::vector exported_names = signature_keys; + absl::StatusOr> module_op = + SavedModelSignatureDefsToMlirImport(saved_model_path, tags, + absl::MakeSpan(exported_names), &ctx, + import_options, &bundle); + if (!module_op.status().ok()) { + return absl::InternalError(absl::StrCat("Failed to import SavedModel: ", + module_op.status().ToString())); + } + + return std::make_pair(module_op->release(), std::move(bundle)); +} + absl::StatusOr> GetFunctionAliases(absl::string_view saved_model_path, const std::unordered_set& tags) { @@ -70,4 +119,35 @@ void UpdateFunctionAliases( }); } +absl::StatusOr ImportSavedModel( + const absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const QuantizationConfig& quantization_config, + const absl::string_view mlir_dump_file_prefix, + absl::flat_hash_map& function_aliases, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) { + TF_ASSIGN_OR_RETURN( + ImportedMlirModuleOp imported_module, + SavedModelToMlirModuleOp(saved_model_path, tags, signature_keys, ctx)); + auto [module_op, saved_model_bundle] = std::move(imported_module); + + UpdateFunctionAliases(function_aliases, module_op); + + // Collect the names of the functions that have aliases so that they may not + // be inlined. + absl::flat_hash_set aliased_function_names; + absl::c_for_each(function_aliases, [&](const auto& aliases) { + return aliased_function_names.insert(aliases.first); + }); + + TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + mlir_dump_file_prefix, /*is_inliner_run=*/true, + /*noinline_functions=*/aliased_function_names, module_op, &ctx, + saved_model_bundle == nullptr ? nullptr + : saved_model_bundle->GetSession(), + /*run_tf_to_stablehlo=*/true, /*deserialize_xla_call_module=*/false)); + return module_op; +} + } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h index 2c20224cf24ed2..631d2e714900aa 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h @@ -19,15 +19,40 @@ limitations under the License. #include #include +#include +#include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" namespace mlir::quant::stablehlo { +// Represents a pair of `mlir::ModuleOp` and `tensorflow::SavedModelBundle`. The +// SavedModelBundle complements the imported ModuleOp by providing access to +// `tensorflow::Session` which may be useful when reading values from resources +// (e.g. `TF::VarHandleOp`s). +using ImportedMlirModuleOp = + std::pair>; + +// Loads a SavedModel at `saved_model_path` and converts it to `mlir::ModuleOp`. +// +// `tags` identify the `tensorflow::MetaGraphDef` to load from the SavedModel. +// Similarly, `signature_keys` identify the functions (`SignatureDef`s) to load +// within the `MetaGraphDef`. `ctx` is the `MLIRContext`, which should outlive +// the returned `ModuleOp`, thus marked with the lifetime bound attribute. +// TODO: b/329206105 - Add unit tests after decomposing preprocessing passes. +absl::StatusOr SavedModelToMlirModuleOp( + absl::string_view saved_model_path, + const std::unordered_set& tags, + const std::vector& signature_keys, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND); + // Gets the function aliases from the SavedModel. absl::StatusOr> GetFunctionAliases(absl::string_view saved_model_path, @@ -44,6 +69,18 @@ void UpdateFunctionAliases( absl::flat_hash_map& function_aliases, ModuleOp module_op); +// Loads a SavedModel to `mlir::ModuleOp` and performs preprocesses including +// shape inference and graph freezing. +// TODO: b/329206105 - Add unit tests after decomposing preprocessing passes. +absl::StatusOr ImportSavedModel( + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const ::stablehlo::quantization::QuantizationConfig& quantization_config, + absl::string_view mlir_dump_file_prefix, + absl::flat_hash_map& function_aliases, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND); + } // namespace mlir::quant::stablehlo #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_IMPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc index eaafdf1770f7f9..015ab7605a05b7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc @@ -15,200 +15,44 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" #include -#include #include #include #include #include -#include "absl/algorithm/container.h" -#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" -#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace mlir::quant::stablehlo { -namespace { using ::stablehlo::quantization::QuantizationConfig; -using ::stablehlo::quantization::io::GetLocalTmpFileName; -using ::tensorflow::AssetFileDef; -using ::tensorflow::MLIRImportOptions; -using ::tensorflow::SavedModelBundle; -using ::tensorflow::SavedModelSignatureDefsToMlirImport; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::ExportedModel; -using ::tensorflow::quantization::PreprocessAndFreezeGraph; using ::tensorflow::quantization::PyFunctionLibrary; -using ::tensorflow::quantization::RunPasses; -using ::tensorflow::quantization::UnfreezeConstantsAndSaveVariables; - -// Sets up and runs the passes for exporting `module_op`. The behavior of the -// exporting passes is controlled by `export_opts`. Returns `AssetFileDef`s that -// associate the input arguments of @main and the asset file names. Asset file -// names will be used to feed the corresponding tensors during initialization -// upon model loading. -absl::StatusOr> RunExportPasses( - const ExportOptions& export_opts, MLIRContext& ctx, ModuleOp module_op) { - if (export_opts.unfreeze_constants) { - TF_RETURN_IF_ERROR(UnfreezeConstantsAndSaveVariables( - export_opts.checkpoint_dir, ctx, module_op)); - LOG(INFO) << "Unfrozen constants and saved variables to checkpoint file: " - << export_opts.checkpoint_dir; - } - - if (absl::Status pass_run_status = RunPasses( - /*name=*/ - export_opts.debug_name, - /*add_passes_func=*/ - [dup_constants = export_opts.duplicate_shape_determining_constants]( - PassManager& pm) { AddExportPasses(pm, dup_constants); }, - ctx, module_op); - !pass_run_status.ok()) { - return pass_run_status; - } - - FailureOr> asset_file_defs = - quant::ConvertAssetArgs(module_op); - if (failed(asset_file_defs)) { - return absl::InternalError("Failed to convert asset args."); - } - - return *asset_file_defs; -} - -// Represents a pair of `mlir::ModuleOp` and `tensorflow::SavedModelBundle`. The -// SavedModelBundle complements the imported ModuleOp by providing access to -// `tensorflow::Session` which may be useful when reading values from resources -// (e.g. `TF::VarHandleOp`s). -using ImportedMlirModuleOp = - std::pair>; - -// Loads a SavedModel at `saved_model_path` and converts it to `mlir::ModuleOp`. -// -// `tags` identify the `tensorflow::MetaGraphDef` to load from the SavedModel. -// Similarly, `signature_keys` identify the functions (`SignatureDef`s) to load -// within the `MetaGraphDef`. `ctx` is the `MLIRContext`, which should outlive -// the returned `ModuleOp`, thus marked with the lifetime bound attribute. -absl::StatusOr SavedModelToMlirModuleOp( - const absl::string_view saved_model_path, - const std::unordered_set& tags, - const std::vector& signature_keys, - MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) { - MLIRImportOptions import_options; - import_options.upgrade_legacy = true; - import_options.lift_variables = false; - import_options.include_variables_in_initializers = true; - - auto bundle = std::make_unique(); - - // Copy to eliminate the `const` qualifier so that `absl::MakeSpan` can be - // called on it. - std::vector exported_names = signature_keys; - absl::StatusOr> module_op = - SavedModelSignatureDefsToMlirImport(saved_model_path, tags, - absl::MakeSpan(exported_names), &ctx, - import_options, &bundle); - if (!module_op.status().ok()) { - return absl::InternalError(absl::StrCat("Failed to import SavedModel: ", - module_op.status().ToString())); - } - - return std::make_pair(module_op->release(), std::move(bundle)); -} - -absl::StatusOr ImportSavedModel( - const absl::string_view saved_model_path, - const std::vector& signature_keys, - const std::unordered_set& tags, - const QuantizationConfig& quantization_config, - absl::flat_hash_map& function_aliases, - MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) { - TF_ASSIGN_OR_RETURN( - ImportedMlirModuleOp imported_module, - SavedModelToMlirModuleOp(saved_model_path, tags, signature_keys, ctx)); - auto [module_op, saved_model_bundle] = std::move(imported_module); - - UpdateFunctionAliases(function_aliases, module_op); - - // Collect the names of the functions that have aliases so that they may not - // be inlined. - absl::flat_hash_set aliased_function_names; - absl::c_for_each(function_aliases, [&](const auto& aliases) { - return aliased_function_names.insert(aliases.first); - }); - - TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph( - /*mlir_dump_file_prefix=*/PreCalibrationComponent::kName, - /*is_inliner_run=*/true, /*noinline_functions=*/aliased_function_names, - module_op, &ctx, - saved_model_bundle == nullptr ? nullptr - : saved_model_bundle->GetSession(), - /*run_tf_to_stablehlo=*/true, /*deserialize_xla_call_module=*/false)); - return module_op; -} - -absl::StatusOr CreateExportedModel( - const std::vector& signature_keys, - const std::unordered_set& tags, - const QuantizationConfig& quantization_config, - absl::flat_hash_map& function_aliases, - MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND, ModuleOp module_op) { - TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); - const ExportOptions export_opts = { - /*duplicate_shape_determining_constants=*/true, - /*unfreeze_constants=*/false, checkpoint_dir, - /*debug_name=*/ - absl::StrCat(PostCalibrationComponent::kName, kExportStepSuffix)}; - - TF_ASSIGN_OR_RETURN(const SmallVector asset_file_defs, - RunExportPasses(export_opts, ctx, module_op)); - - UpdateFunctionAliases(function_aliases, module_op); - - return ConvertMlirModuleToExportedModel( - module_op, checkpoint_dir, function_aliases, - {asset_file_defs.begin(), asset_file_defs.end()}); -} - -} // namespace StaticRangePtqComponent::StaticRangePtqComponent( absl::Nonnull ctx, @@ -243,17 +87,13 @@ absl::StatusOr StaticRangePtqComponent::Run( absl::Status QuantizeStaticRangePtq( const absl::string_view src_saved_model_path, const absl::string_view dst_saved_model_path, - QuantizationConfig quantization_config, + const QuantizationConfig& quantization_config, const std::vector& signature_keys, const absl::flat_hash_map& signature_def_map, const PyFunctionLibrary& py_function_library) { std::unordered_set tags; tags.insert(quantization_config.tf_saved_model().tags().begin(), quantization_config.tf_saved_model().tags().end()); - if (!quantization_config.has_calibration_options()) { - *quantization_config.mutable_calibration_options() = - GetDefaultCalibrationOptions(); - } std::unique_ptr ctx = CreateMlirContextForQuantization(); @@ -267,7 +107,8 @@ absl::Status QuantizeStaticRangePtq( TF_ASSIGN_OR_RETURN( ModuleOp module_op, ImportSavedModel(src_saved_model_path, signature_keys, tags, - quantization_config, *function_aliases, *ctx)); + quantization_config, PreCalibrationComponent::kName, + *function_aliases, *ctx)); StaticRangePtqComponent static_range_ptq_component( ctx.get(), &py_function_library, src_saved_model_path, signature_keys, @@ -278,7 +119,8 @@ absl::Status QuantizeStaticRangePtq( TF_ASSIGN_OR_RETURN( const ExportedModel post_calibrated_exported_model, CreateExportedModel(signature_keys, tags, quantization_config, - *function_aliases, *ctx, module_op)); + PostCalibrationComponent::kName, *function_aliases, + *ctx, module_op)); // Remove the `tpu` tag for exporting because the output quantized model is // essentially a CPU model. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h index e5056418bbae55..69bd9da6733c0c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h @@ -37,17 +37,6 @@ limitations under the License. namespace mlir::quant::stablehlo { -using ::stablehlo::quantization::CalibrationOptions; - -// Create default configuration for the calibration step, which is the min/max -// calibration method. -inline CalibrationOptions GetDefaultCalibrationOptions() { - CalibrationOptions options{}; - options.set_calibration_method( - CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); - return options; -} - // Component for static-range post-training quantization (PTQ). // TODO: b/320607042 - Add tests in python level. class StaticRangePtqComponent : public Component { @@ -102,7 +91,7 @@ class StaticRangePtqComponent : public Component { absl::Status QuantizeStaticRangePtq( absl::string_view src_saved_model_path, absl::string_view dst_saved_model_path, - ::stablehlo::quantization::QuantizationConfig quantization_config, + const ::stablehlo::quantization::QuantizationConfig& quantization_config, const std::vector& signature_keys, const absl::flat_hash_map& signature_def_map, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc new file mode 100644 index 00000000000000..bbd9a9c25620bd --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc @@ -0,0 +1,114 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace mlir::quant::stablehlo { + +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::SignatureDef; +using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::PyFunctionLibrary; +using ::tensorflow::quantization::RunPasses; + +WeightOnlyPtqComponent::WeightOnlyPtqComponent(absl::Nonnull ctx) + : ctx_(ABSL_DIE_IF_NULL(ctx)) {} // Crash OK + +absl::StatusOr WeightOnlyPtqComponent::Run( + ModuleOp module_op, const QuantizationConfig& config) { + TF_RETURN_IF_ERROR(RunPasses( + kName, /*add_passes_func=*/ + [&config](PassManager& pm) { + AddWeightOnlyQuantizationPasses(pm, config.specs(), + config.pipeline_config(), + config.debugger_config()); + }, + *ctx_, module_op)); + return module_op; +} + +absl::Status QuantizeWeightOnlyPtq( + const absl::string_view src_saved_model_path, + const absl::string_view dst_saved_model_path, + QuantizationConfig quantization_config, + const std::vector& signature_keys, + const absl::flat_hash_map& signature_def_map, + const PyFunctionLibrary& py_function_library) { + std::unordered_set tags; + tags.insert(quantization_config.tf_saved_model().tags().begin(), + quantization_config.tf_saved_model().tags().end()); + + std::unique_ptr ctx = CreateMlirContextForQuantization(); + + absl::StatusOr> + function_aliases = GetFunctionAliases(src_saved_model_path, tags); + if (!function_aliases.ok()) { + return absl::InternalError(absl::StrCat( + "Failed to get function alias: ", function_aliases.status().message())); + } + + TF_ASSIGN_OR_RETURN( + ModuleOp module_op, + ImportSavedModel(src_saved_model_path, signature_keys, tags, + quantization_config, WeightOnlyPtqComponent::kName, + *function_aliases, *ctx)); + + WeightOnlyPtqComponent weight_only_ptq_component(ctx.get()); + TF_ASSIGN_OR_RETURN( + module_op, weight_only_ptq_component.Run(module_op, quantization_config)); + + TF_ASSIGN_OR_RETURN( + const ExportedModel post_calibrated_exported_model, + CreateExportedModel(signature_keys, tags, quantization_config, + WeightOnlyPtqComponent::kName, *function_aliases, + *ctx, module_op)); + + // Remove the `tpu` tag for exporting because the output quantized model is + // essentially a CPU model. + tags.erase("tpu"); + + py_function_library.SaveExportedModel( + dst_saved_model_path, post_calibrated_exported_model, + src_saved_model_path, tags, signature_def_map); + + return absl::OkStatus(); +} + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h new file mode 100644 index 00000000000000..bf23e93246c700 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h @@ -0,0 +1,80 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_WEIGHT_ONLY_PTQ_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_WEIGHT_ONLY_PTQ_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::quant::stablehlo { + +// Performs int8 weight-only quantization on dot_general ops. +// +// The resulting `ModuleOp` contains quantized StableHLO ops serialized in +// `TF::XlaCallModuleOp`s. They are quantized using the weight constants, not +// relying on calibration. +class WeightOnlyPtqComponent : public Component { + public: + // Used for debugging purposes. + static constexpr absl::string_view kName = "quant_ptq_weight_only"; + + explicit WeightOnlyPtqComponent(absl::Nonnull ctx); + + absl::StatusOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + private: + absl::Nonnull ctx_; +}; + +// Runs weight-only quantization on a SavedModel at +// `src_saved_model_path` and saves the resulting model to +// `dst_saved_model_path`. +// +// `quantization_config` configures the quantization behavior for the +// weight-only quantization. +// +// `signature_keys` specify the signatures that correspond to functions to be +// quantized. `signature_def_map` connects the signature keys to +// `SignatureDef`s. +// +// Returns a non-OK status when the quantization is not successful. +// LINT.IfChange +absl::Status QuantizeWeightOnlyPtq( + absl::string_view src_saved_model_path, + absl::string_view dst_saved_model_path, + ::stablehlo::quantization::QuantizationConfig quantization_config, + const std::vector& signature_keys, + const absl::flat_hash_map& + signature_def_map, + const tensorflow::quantization::PyFunctionLibrary& py_function_library); +// LINT.ThenChange(../python/pywrap_quantization.cc:weight_only_ptq) + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_WEIGHT_ONLY_PTQ_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD index 0f8bb04d796a6e..35584857f5761f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD @@ -41,12 +41,14 @@ tf_cc_test( ":stablehlo_op_quant_spec", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common:test_base", + "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:test", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc index a78a1feec9077e..c78ee607993385 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc @@ -47,6 +47,7 @@ namespace { using ::mlir::stablehlo::DotGeneralOp; using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::StaticRangePtq; // Whether it represents a lifted function (i.e. `op` is the corresponding // `XlaCallModuleOp`) that is explicitly marked `NoQuantization`. @@ -61,6 +62,31 @@ bool IsDenylistedLiftedFunction(Operation* op) { return false; } +// Populates `spec.coeff_op_quant_dim` according to `xla_call_module_op`'s +// `_quantization_method` attribute. If there is an input `QuantizedType` with +// `dimension_specs` set, which represents the quantization dimension for the +// input, then the corresponding operand index -> quantization dimension mapping +// is set for `spec`. +// TODO: b/323478683 - Duplicate tracking of config will be eliminated. +// `OpQuantSpec` will be deprecated and `Method` will be used instead. +void PopulateCoeffOpQuantDimIfPerChannelQuantized( + TF::XlaCallModuleOp xla_call_module_op, OpQuantSpec& spec) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.ok() && method->has_static_range_ptq()) { + // TODO: b/331145946 - Use `Method` accessors. + const StaticRangePtq& static_range_ptq_spec = method->static_range_ptq(); + // Look for quantized dimension specs for each quantized type and + // populate `coeff_op_quant_dim`. + for (const auto& [operand_idx, quantized_type] : + static_range_ptq_spec.input_quantized_types()) { + if (quantized_type.has_dimension_specs()) { + spec.coeff_op_quant_dim[operand_idx] = + quantized_type.dimension_specs().dimension(); + } + } + } +} + } // namespace std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { @@ -72,8 +98,12 @@ std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { if (!function_name.starts_with("composite_")) { return spec; } + if (function_name.contains("conv")) { - spec->coeff_op_quant_dim[1] = 3; + // Looks up `Method` to see if it should be per-channel quantized and + // populates the spec accordingly. + PopulateCoeffOpQuantDimIfPerChannelQuantized(call_op, *spec); + if (function_name.contains("with_bias")) { spec->biases_params[2] = {{0, 1}, quant::GetUniformQuantizedTypeForBias}; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc index 39baea749992d1..b3ba4818284498 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc @@ -15,14 +15,18 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" +#include + #include #include #include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/test.h" @@ -30,7 +34,10 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { +using ::testing::IsEmpty; using ::testing::NotNull; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; using IsOpQuantizableStableHloTest = ::mlir::quant::QuantizationTestBase; @@ -208,5 +215,74 @@ TEST_F(IsOpQuantizableStableHloTest, DenylistedXlaCallModuleOpNotQuantizable) { EXPECT_FALSE(IsOpQuantizableStableHlo(xla_call_module_op)); } +using GetStableHloOpQuantSpecTest = ::mlir::quant::QuantizationTestBase; + +TEST_F(GetStableHloOpQuantSpecTest, + EmptyCoeffOpQuantDimForPerTensorQuantizedConvolution) { + // A `TF::XlaCallModuleOp` with `_quantization_method = "static_range_ptq + // {}"`, representing a per-tensor static-range PTQ quantization. + constexpr absl::string_view + kXlaCallModuleOpWithPerTensorQuantizedConvolution = R"mlir( + func.func @main(%arg0: tensor<1x1x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x1x4xf32> { + %0 = "tf.XlaCallModule"(%arg0, %arg1) <{Sout = [#tf_type.shape<1x1x4>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> + { + _entry_function = @composite_conv_fn_1, + _original_entry_function = "composite_conv_fn_1", + _quantization_method = "static_range_ptq {}", + _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, + _tfl_quant_trait = "fully_quantizable" + } : (tensor<1x1x3xf32>, tensor<3x4xf32>) -> tensor<1x1x4xf32> + return %0 : tensor<1x1x4xf32> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kXlaCallModuleOpWithPerTensorQuantizedConvolution); + ASSERT_TRUE(module_op); + + const FailureOr xla_call_module_op = + FindFirstOpFromMainFunc(*module_op); + ASSERT_TRUE(succeeded(xla_call_module_op)); + + const std::unique_ptr op_quant_spec = + GetStableHloOpQuantSpec(*xla_call_module_op); + ASSERT_THAT(op_quant_spec, NotNull()); + + EXPECT_THAT(op_quant_spec->coeff_op_quant_dim, IsEmpty()); +} + +TEST_F(GetStableHloOpQuantSpecTest, + EmptyCoeffOpQuantDimForPerChannelQuantizedConvolution) { + constexpr absl::string_view + kXlaCallModuleOpWithPerChannelQuantizedConvolution = R"mlir( + func.func @main(%arg0: tensor<1x1x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x1x4xf32> { + %0 = "tf.XlaCallModule"(%arg0, %arg1) <{Sout = [#tf_type.shape<1x1x4>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> + { + _entry_function = @composite_conv_fn_1, + _original_entry_function = "composite_conv_fn_1", + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, + _tfl_quant_trait = "fully_quantizable" + } : (tensor<1x1x3xf32>, tensor<3x4xf32>) -> tensor<1x1x4xf32> + return %0 : tensor<1x1x4xf32> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kXlaCallModuleOpWithPerChannelQuantizedConvolution); + ASSERT_TRUE(module_op); + + const FailureOr xla_call_module_op = + FindFirstOpFromMainFunc(*module_op); + ASSERT_TRUE(succeeded(xla_call_module_op)); + + const std::unique_ptr op_quant_spec = + GetStableHloOpQuantSpec(*xla_call_module_op); + ASSERT_THAT(op_quant_spec, NotNull()); + + EXPECT_THAT(op_quant_spec->coeff_op_quant_dim, + UnorderedElementsAre(Pair(1, 3))); +} + } // namespace } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc index b07b833429f8b6..33d66316870798 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc @@ -26,8 +26,7 @@ namespace mlir::quant::stablehlo { void AddQuantizationLoweringPasses(mlir::OpPassManager& pm) { // These passes are grouped together and must run in this specific order. pm.addNestedPass(CreateConvertTFQuantOpsToMHLOPass()); - pm.addNestedPass(mhlo::createChloLegalizeToHloPass( - /*legalizeBroadcasts=*/true, /*expandCompositions=*/false)); + pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass( mhlo::createMhloQuantLegalizeToIntPass()); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc new file mode 100644 index 00000000000000..5be09ce2ad47ef --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc @@ -0,0 +1,288 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/base/nullability.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" + +namespace mlir::quant::stablehlo { + +#define GEN_PASS_DEF_DEFERACTIVATIONTRANSPOSEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +namespace { + +using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::BroadcastInDimOp; +using ::mlir::stablehlo::MaxOp; +using ::mlir::stablehlo::TransposeOp; + +// Returns `success()` if `op` is a `TransposeOp` with permutation attribute +// equivalent to `permuation`. +LogicalResult IsTransposeOpWithPermuation(absl::Nullable op, + const ArrayRef permutation) { + auto transpose_op = dyn_cast_or_null(op); + return success(transpose_op != nullptr && transpose_op.getPermutation() == + ArrayRef(permutation)); +} + +// Convenience function to create a `TransposeOp` with a given `permutation`. +// The Location is set as `input`'s loc. +TransposeOp CreateTransposeOp(Value input, const ArrayRef permutation, + PatternRewriter& rewriter) { + return rewriter.create( + input.getLoc(), input, rewriter.getDenseI64ArrayAttr(permutation)); +} + +// Defers the transpose of the left-hand side (LHS) to the right-hand side and +// the result of a binary operation. In detail, this rewrites the +// `op(transpose(%rhs), %lhs)` to `transpose(op(%rhs, transpose(%lhs)))`. The +// LHS transpose permutation must be a NCHW->NHWC permutation. +template +void DeferRhsTransposeForBinaryOp(OpT op, PatternRewriter& rewriter) { + auto transpose_op = cast(op.getOperand(0).getDefiningOp()); + Value lhs_pre_transpose = transpose_op.getOperand(); + + // NCHW -> NHWC for the right-hand side, to match the operand's shape. + Value rhs = op.getOperand(1); + TransposeOp rhs_transpose_op = CreateTransposeOp( + /*input=*/rhs, kNchwToNhwcPermutation, rewriter); + + auto new_binary_op = + rewriter.create(op.getLoc(), lhs_pre_transpose, rhs_transpose_op); + + // NHWC -> NCHW for the output, to match the shapes of `op`'s users. + TransposeOp output_transpose_op = CreateTransposeOp( + /*input=*/new_binary_op, kNhwcToNchwPermutation, rewriter); + + rewriter.replaceAllUsesWith(op.getResult(), output_transpose_op); +} + +// "Climbs up" the `op` if `op` is a `BraodcastInDimOp` and returns the defining +// op of its operand. Returns `op` otherwise. May return `nullptr` when the +// `BroadcastInDimOp`'s operand is a block argument. +absl::Nullable SkipUpwardsOptionalBroadcastInDimOp( + absl::Nonnull op) { + if (auto broadcast_in_dim_op = dyn_cast_or_null(op); + broadcast_in_dim_op != nullptr) { + return broadcast_in_dim_op.getOperand().getDefiningOp(); + } + return op; +} + +class DeferActivationTransposeForAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(AddOp op) const override { + // Only supports the case for 2D convolution. + const Value lhs = op.getOperand(0); + if (!HasRankOf(lhs, /*rank=*/4)) return failure(); + + const Value rhs = op.getOperand(1); + Operation* rhs_op = rhs.getDefiningOp(); + if (rhs_op == nullptr) return failure(); + + // Ignore the optional `BroadcastInDimOp` in between the constant and RHS. + rhs_op = SkipUpwardsOptionalBroadcastInDimOp(rhs_op); + + if (rhs_op == nullptr || !rhs_op->hasTrait()) { + return failure(); + } + + // Match LHS permutation that converts: NHWC -> NCHW. + return IsTransposeOpWithPermuation(lhs.getDefiningOp(), + kNhwcToNchwPermutation); + } + + void rewrite(AddOp op, PatternRewriter& rewriter) const override { + DeferRhsTransposeForBinaryOp(op, rewriter); + } +}; + +// Rewrites the `reduce_window(transpose(%activation), %init_value)` patterns to +// `transpose(reduce_window(%activation), %init_value)`, deferring the transpose +// to the result. The reduce function should be equivalent to +// `stablehlo.maximum`, representing max pooling. +class DeferActivationTransposeForMaxPoolReduceWindowOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(mlir::stablehlo::ReduceWindowOp op) const override { + if (failed(MatchMaxPoolReduceWindowOp(op))) return failure(); + + // Match only when the lhs is connected to a transpose. + // Only supports the case commonly appearing for 2D convolutions. + Value lhs = op.getOperand(0); + if (!HasRankOf(lhs, /*rank=*/4)) return failure(); + + // Match input permutation that converts: NHWC -> NCHW. + return IsTransposeOpWithPermuation(lhs.getDefiningOp(), + kNhwcToNchwPermutation); + } + + // Pushes the transpose op at the input to the result. + void rewrite(mlir::stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { + auto transpose_op = cast(op.getOperand(0).getDefiningOp()); + + const auto result_type = op.getResult(0).getType().cast(); + const SmallVector new_result_shape = + Permute(result_type.getShape(), kNchwToNhwcPermutation); + + const TensorType new_result_type = + result_type.cloneWith(new_result_shape, result_type.getElementType()); + + // Create a new `stablehlo.reduce_window` with all relevant attributes + // permutated to match the new operand & result type. + auto new_reduce_window_op = + rewriter.create( + op.getLoc(), new_result_type, transpose_op.getOperand(), + /*init_value=*/op.getOperand(1), + /*window_dimensions=*/ + PermuteI64ArrayAttr(rewriter, op.getWindowDimensionsAttr(), + kNchwToNhwcPermutation), + /*window_strides=*/ + PermuteI64ArrayAttr(rewriter, op.getWindowStridesAttr(), + kNchwToNhwcPermutation), + /*base_dilations=*/ + PermuteI64ArrayAttr(rewriter, op.getBaseDilationsAttr(), + kNchwToNhwcPermutation), + /*window_dilations=*/ + PermuteI64ArrayAttr(rewriter, op.getWindowDilationsAttr(), + kNchwToNhwcPermutation), + /*padding=*/DenseIntElementsAttr(nullptr)); + + // Clone the reduce body. It is not affected by the permutation. + IRMapping mapping; + op.getBody().cloneInto(&new_reduce_window_op.getBody(), mapping); + + // Introduce a transpose to the result to match the shapes of `op`'s uses. + TransposeOp result_transpose_op = CreateTransposeOp( + /*input=*/new_reduce_window_op.getResult(0), kNhwcToNchwPermutation, + rewriter); + + rewriter.replaceAllUsesWith(op.getResult(0), result_transpose_op); + } + + private: + // Permutes `array_attr` with `permutation`. The number of elements in + // `array_attr` and `permutation` must be equal. Returns a null attribute + // if `array_attr` is null. + DenseI64ArrayAttr PermuteI64ArrayAttr( + PatternRewriter& rewriter, const DenseI64ArrayAttr array_attr, + const ArrayRef permutation) const { + if (array_attr == nullptr) return DenseI64ArrayAttr(nullptr); + + return rewriter.getDenseI64ArrayAttr( + Permute(array_attr, permutation)); + } + + LogicalResult MatchMaxPoolReduceWindowOp( + mlir::stablehlo::ReduceWindowOp op) const { + // TODO: b/321099943 - Support explicit padding. + if (HasPadding(op)) return failure(); + + // Check that the reduce-window body is a max operation. + return success(IsMaxFunction(op.getBody().front())); + } + + // Whether `block` semantically corresponds to a `stablehlo.maximum` op. + bool IsMaxFunction(Block& block) const { + if (block.getNumArguments() != 2) return false; + + auto return_op = cast(block.getTerminator()); + if (return_op.getNumOperands() != 1) return false; + + auto max_op = dyn_cast_or_null( + return_op.getOperands().front().getDefiningOp()); + if (!max_op) return false; + + return (max_op.getLhs() == block.getArgument(0)) && + (max_op.getRhs() == block.getArgument(1)); + } + + // Whether `op` has the `padding` attribute (which is optional). + bool HasPadding(mlir::stablehlo::ReduceWindowOp op) const { + return op.getPadding() != std::nullopt; + } +}; + +// Rewrites `maximum(transpose(%rhs), %lhs)` patterns to +// `transpose(maximum(%rhs, transpose(%lhs)))`. +class DeferActivationTransposeForMaxOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(MaxOp op) const override { + Value input = op.getOperand(0); + if (!HasRankOf(input, /*rank=*/4)) return failure(); + + const Value max_value = op.getOperand(1); + Operation* max_value_op = max_value.getDefiningOp(); + if (max_value_op == nullptr || + !max_value_op->hasTrait()) { + return failure(); + } + + return IsTransposeOpWithPermuation(input.getDefiningOp(), + kNhwcToNchwPermutation); + } + + void rewrite(MaxOp op, PatternRewriter& rewriter) const override { + DeferRhsTransposeForBinaryOp(op, rewriter); + } +}; + +} // namespace + +class DeferActivationTransposePass + : public impl::DeferActivationTransposePassBase< + DeferActivationTransposePass> { + private: + void runOnOperation() override; +}; + +void DeferActivationTransposePass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { + func_op->emitWarning() << "Failed to converge patterns: " << getArgument(); + } +} + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose_pass.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc similarity index 92% rename from tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose_pass.cc rename to tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc index 52a101b997ad89..051745c0d6792b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose_pass.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" namespace mlir::quant::stablehlo { @@ -53,17 +54,6 @@ int64_t GetContiguousOffset(const ArrayRef indices, return contiguous_offset; } -// Permutes `values` with `permutation`. Returns the permuted values. Sizes of -// `values` and `permutation` must be equal. -SmallVector Permute(const ArrayRef values, - const ArrayRef permutation) { - SmallVector permuted_values(/*Size=*/values.size(), /*Value=*/0); - for (auto [i, permutation_idx] : llvm::enumerate(permutation)) { - permuted_values[i] = values[permutation_idx]; - } - return permuted_values; -} - // Performs transposition of a tensor represented as a contiguous element array. // Assumes row-major order. The shape of the input tensor and the desired // permutation is registered during construction, and calling `TransposeValues` @@ -74,7 +64,7 @@ class DenseElementsTransposer { const ArrayRef permutation) : rank_(original_shape.size()), original_shape_(original_shape), - target_shape_(Permute(original_shape, permutation)), + target_shape_(Permute(original_shape, permutation)), permutation_(permutation) {} // Transposes `values` with the permutation. Returns the transposed values. @@ -102,7 +92,7 @@ class DenseElementsTransposer { GetContiguousOffset(current_indices, original_shape_); const SmallVector target_indices = - Permute(current_indices, permutation_); + Permute(current_indices, permutation_); const int64_t target_index = GetContiguousOffset(target_indices, target_shape_); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize_hybrid.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc similarity index 90% rename from tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize_hybrid.cc rename to tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc index 77a389398270a9..9fb1e9e985d15e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize_hybrid.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc @@ -47,20 +47,21 @@ limitations under the License. namespace mlir::quant::stablehlo { -#define GEN_PASS_DEF_PREPAREQUANTIZEHYBRIDPASS +#define GEN_PASS_DEF_INSERTWEIGHTPARAMPASS #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" namespace { -// Prepare hybrid quantization for weight-only quantization and dynamic range -// quantization of `stablehlo.convolution` and `stablehlo.dot_general`. -class PrepareQuantizeHybridPass - : public impl::PrepareQuantizeHybridPassBase { +// Inserts quantization parameters of weights for weight-only quantization and +// dynamic range quantization of `stablehlo.convolution` and +// `stablehlo.dot_general`. +class InsertWeightParamPass + : public impl::InsertWeightParamPassBase { public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizeHybridPass) + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertWeightParamPass) - using impl::PrepareQuantizeHybridPassBase< - PrepareQuantizeHybridPass>::PrepareQuantizeHybridPassBase; + using impl::InsertWeightParamPassBase< + InsertWeightParamPass>::InsertWeightParamPassBase; private: void runOnOperation() override; @@ -96,7 +97,8 @@ class InsertWeightParamPattern return false; } Operation* user = operand.getOwner(); - if (auto call_op = cast(user)) { + if (isa(user)) { + auto call_op = cast(user); const StringRef function_name = GetEntryFunctionName(call_op); const bool is_conv_or_dot = function_name.contains("conv") || function_name.contains("dot_general"); @@ -134,7 +136,7 @@ class InsertWeightParamPattern } }; -void PrepareQuantizeHybridPass::runOnOperation() { +void InsertWeightParamPass::runOnOperation() { func::FuncOp func = getOperation(); MLIRContext* context = func.getContext(); RewritePatternSet patterns(context); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc index 13fe29fe787324..a4bf42ec6f8eba 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -76,6 +77,12 @@ bool FloatValueEquals(const Attribute& attr, const double value) { }); } +inline void TrimTrailingWhitespaces(std::string& str) { + while (!str.empty() && str.back() == ' ') { + str.pop_back(); + } +} + // Lifts quantizable units as separate functions, thereby identifying the // boundaries of quantizable subgraphs. `QuantizationSpecs` influences how // quantizable units are lifted. @@ -146,16 +153,22 @@ class FunctionNameMatcher { std::unique_ptr match_regex_; // NOLINT }; -// Converts `Method` to text proto representation. All newline characters are -// removed. +// Converts `Method` to a single-line textproto representation. Returns +// `failure()` when converting to textproto failed. FailureOr QuantizationMethodToTextProto(const Method& method) { + TextFormat::Printer printer; + printer.SetSingleLineMode(true); + std::string method_txtpb; - if (!TextFormat::PrintToString(method, &method_txtpb)) { + if (!printer.PrintToString(method, &method_txtpb)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to convert Method to textproto\n."); return failure(); } - // Remove newlines. - absl::StrReplaceAll({{"\n", ""}}, &method_txtpb); + // Single line mode might have an extra space at the end, due to the internal + // details of `Printer`. + TrimTrailingWhitespaces(method_txtpb); + return method_txtpb; } @@ -168,11 +181,6 @@ LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec, if (!main_func) return failure(); const Method& quantization_method = spec.method(); - if (!quantization_method.has_no_quantization()) { - module_op->emitError() << "Unsupported quantization method: " - << quantization_method.DebugString() << "\n"; - return failure(); - } FailureOr quantization_method_txtpb = QuantizationMethodToTextProto(quantization_method); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td index 07598356cce7d3..eaa8a9092f41f2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td @@ -67,3 +67,11 @@ def LiftGather : Pat< (NamedAttr<"slice_sizes"> $slice_sizes), (NamedAttr<"indices_are_sorted"> (DefaultOrNullAttr $indices_are_sorted)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $operand)], [], (addBenefit 1)>; + +def LiftAdd : Pat< + (StableHLO_AddOp:$res + $lhs, $rhs), + (LiftAsTFXlaCallModule<"composite_add_fn"> + (ArgumentList $lhs, $rhs), + (ResultList $res)), + [(IsNotInLiftedFunc $res), (IsNotInStableHloOpRegion $res)], [], (addBenefit 1)>; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc index 5ba80df30a9f2d..521f701598fb0a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc @@ -28,6 +28,7 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" namespace mlir::quant::stablehlo { @@ -72,20 +73,20 @@ class RewriteNchwConvolutionToNhwc // Transpose the input tensor: [b, f, 0, 1] => [b, 0, 1, f] Value input = op->getOperand(0); const TensorType new_input_tensor_type = GetTransposedTensorType( - input.getType().cast(), kActivationPermutation); + input.getType().cast(), kNchwToNhwcPermutation); auto input_transpose_op = rewriter.create( op.getLoc(), /*resultType0=*/new_input_tensor_type, /*operand=*/input, - rewriter.getDenseI64ArrayAttr(kActivationPermutation)); + rewriter.getDenseI64ArrayAttr(kNchwToNhwcPermutation)); // Transpose the filter tensor: [o, i, 0, 1] => [0, 1, i, o] Value filter = op->getOperand(1); const TensorType new_filter_tensor_type = GetTransposedTensorType( - filter.getType().cast(), kFilterPermutation); + filter.getType().cast(), kOihwToHwioPermutation); auto filter_transpose_op = rewriter.create( op.getLoc(), /*resultType0=*/new_filter_tensor_type, /*operand=*/filter, - rewriter.getDenseI64ArrayAttr(kFilterPermutation)); + rewriter.getDenseI64ArrayAttr(kOihwToHwioPermutation)); // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] const auto new_dimension_nums = rewriter.getAttr( @@ -99,7 +100,7 @@ class RewriteNchwConvolutionToNhwc // Determine the shape of the output tensor: [b, f, 0, 1] => [b, 0, 1, f] auto output_tensor_type = op->getResult(0).getType().cast(); const TensorType new_conv_output_tensor_type = - GetTransposedTensorType(output_tensor_type, kOutputPermutation); + GetTransposedTensorType(output_tensor_type, kNchwToNhwcPermutation); // window_strides, padding, lhs_dilation, rhs_dilation, window_reversal are // reused without modification because the ordering of spatial dimensions @@ -125,31 +126,12 @@ class RewriteNchwConvolutionToNhwc auto output_transpose_op = rewriter.create( new_convolution_op.getLoc(), /*resultType0=*/output_tensor_type, /*operand=*/new_convolution_op, - rewriter.getDenseI64ArrayAttr(kOutputReversePermutation)); + rewriter.getDenseI64ArrayAttr(kNhwcToNchwPermutation)); rewriter.replaceAllUsesWith(op, output_transpose_op); } private: - // Permutation to transpose the input tensor from [b, f, 0, 1] to - // [b, 0, 1, f]. - static constexpr std::array kActivationPermutation = {0, 2, 3, 1}; - - // Permutation to transpose the filter tensor from [o, i, 0, 1] to - // [0, 1, i, o]. - static constexpr std::array kFilterPermutation = {2, 3, 1, 0}; - - // Permutation to transpose the output tensor from [b, f, 0, 1] to - // [b, 0, 1, f]. This is used to determine the shape of the new - // `ConvolutionOp`'s output tensor. - static constexpr std::array kOutputPermutation = {0, 2, 3, 1}; - - // Permutation to transpose the output tensor from [b, 0, 1, f] to - // [b, f, 0, 1]. This is used to revert the new output tensor of - // `ConvolutionOp` with a `TransposeOp`. - static constexpr std::array kOutputReversePermutation = {0, 3, 1, - 2}; - // Matches input dimensions corresponding to: [b, f, 0, 1]. bool MatchInputDimensionNumbers( const ConvDimensionNumbersAttr dimension_numbers) const { @@ -183,21 +165,9 @@ class RewriteNchwConvolutionToNhwc TensorType GetTransposedTensorType( const TensorType type, const ArrayRef permutation) const { const SmallVector after_shape = - PermuteShape(type.getShape(), permutation); + Permute(type.getShape(), permutation); return type.cloneWith(after_shape, type.getElementType()); } - - // Permutes the shape according to the permutation. The size of `shape` and - // `permutation` should be equal. - SmallVector PermuteShape(const ArrayRef shape, - const ArrayRef permutation) const { - const int64_t size = shape.size(); - SmallVector after_shape(size); - for (int i = 0; i < size; ++i) { - after_shape[i] = shape[permutation[i]]; - } - return after_shape; - } }; } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index fb3f5fcb0a21c3..63f6f822dbebdf 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -60,13 +60,17 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function "enable-per-channel-quantized-weight", "bool", /*default=*/"true", "Whether to enable per-channel quantized weights.">, + Option<"enable_full_int_quantization_", + "enable-full-int-quantization", + "bool", /*default=*/"false", + "Whether to enable full int quantization, including non compute-heavy ops.">, Option<"mlir_dump_file_name_", "mlir-dump-file-name", "std::optional", /*default=*/"std::nullopt", "MLIR dump file name.">, Option<"enable_weight_only_", "enable-weight-only", "bool", /*default=*/"false", - "Whether to produce weight-only quantized op for dot_general op.">, + "Whether to produce weight-only quantized op for convolution and dot_general op.">, ]; let dependentDialects = [ "mlir::arith::ArithDialect", @@ -102,10 +106,14 @@ def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { "enable-per-channel-quantized-weight", "bool", /*default=*/"true", "Whether to enable per-channel quantized weights.">, + Option<"enable_full_int_quantization_", + "enable-full-int-quantization", + "bool", /*default=*/"false", + "Whether to apply full int quantization, including non compute-heavy ops.">, Option<"enable_weight_only_", "enable-weight-only", "bool", /*default=*/"false", - "Whether to produce weight-only quantized op for dot_general op.">, + "Whether to produce weight-only quantized op for convolution and dot_general op.">, ]; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", @@ -162,8 +170,22 @@ def NchwConvolutionToNhwcPass : Pass<"stablehlo-nchw-convolution-to-nhwc", "mlir let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; } -def PrepareQuantizeHybridPass : Pass<"stablehlo-prepare-quantize-hybrid", "mlir::func::FuncOp"> { - let summary = "Prepare hybrid quantization for weight-only quantization and dynamic range quantization."; +def DeferActivationTransposePass : Pass<"stablehlo-defer-activation-transpose", "mlir::func::FuncOp"> { + let summary = "Merges stablehlo.transpose for activations."; + let description = [{ + Defers activation transposes (e.g. LHS of `stablehlo.add`) to the output and + optionally inserts `stablehlo.transpose`s to match the shape of operands. + This is useful when recursively pushing down the extra `stablehlo.transpose` + inserted to activation tensors after running `NchwConvolutionToNhwcPass`. + + Currently only converts limited cases that appear in NCHW->NHWC 2D + convolution conversion, to avoid introducing unwanted pessimizations. + }]; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} + +def InsertWeightParamPass : Pass<"stablehlo-insert-weight-param", "mlir::func::FuncOp"> { + let summary = "Insert quantization parameters of weights for weight-only quantization and dynamic range quantization."; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", "TF::TensorFlowDialect", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 72702621f6e8b4..10b15f1132fe62 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "absl/algorithm/container.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -49,9 +48,11 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -73,6 +74,9 @@ using ::mlir::stablehlo::GatherOp; using ::mlir::stablehlo::GetDimensionSizeOp; using ::mlir::stablehlo::ReshapeOp; using ::mlir::stablehlo::UniformQuantizeOp; +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizedType; +using ::stablehlo::quantization::StaticRangePtq; constexpr StringRef kCompositeFuncPrefix = "composite_"; constexpr StringRef kQuantizedFuncPrefix = "quantized_"; @@ -139,22 +143,6 @@ Operation* GetBroadcastedUserOp(Operation* op) { return target_op; } -// Checks if one of the inputs and outputs are quantized. -bool HasQuantizedOperandOrOutput(Operation* call_op) { - SmallVector arg_types; - for (const Value arg : call_op->getOperands()) { - arg_types.push_back(arg.getType()); - } - - SmallVector output_types; - for (const Value output : call_op->getResults()) { - output_types.push_back(output.getType()); - } - - return absl::c_any_of(arg_types, IsQuantizedTensorType) && - absl::c_any_of(output_types, IsQuantizedTensorType); -} - // Gets the corresponding quantized function name from the given function name. // Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn" std::string GetQuantizedFunctionName(const StringRef func_name) { @@ -170,7 +158,7 @@ std::string GetQuantizedFunctionName(const StringRef func_name) { // 3. It should also have the `kEntryFuncAttrName` attribute, which points to // the function that `xla_call_module_op` represents. bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { - return HasQuantizedOperandOrOutput(xla_call_module_op) && + return !IsOpNotQuantized(xla_call_module_op) && xla_call_module_op->hasAttr(kQuantTraitAttrName) && xla_call_module_op->hasAttr(kEntryFuncAttrName); } @@ -287,6 +275,7 @@ class EntryFuncBodyQuantizationPattern { // Rewrites the `entry_func_op`'s body. virtual void rewrite(func::FuncOp entry_func_op, + const Method& quantization_method, PatternRewriter& rewriter) const = 0; }; @@ -417,63 +406,23 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, } } -template -// Match for tensor manipulation op. -LogicalResult MatchSingularOp(func::FuncOp entry_func_op) { - const auto op_iterator_range = entry_func_op.getOps(); - if (op_iterator_range.empty()) { - LLVM_DEBUG(llvm::dbgs() << "Function does not have " - << SingularOp::getOperationName() << " op.\n"); - return failure(); - } - if (!isa( - (*op_iterator_range.begin()).getResult().getType())) { - LLVM_DEBUG(llvm::dbgs() << SingularOp::getOperationName() - << " op must have ranked tensor type.\n"); - return failure(); - } - return success(); -} - -template -void RewriteSingularOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { - SingularOp singular_op = *entry_func_op.getOps().begin(); - - const Type operand_type = entry_func_op.getArgumentTypes()[0]; - const Type func_result_type = entry_func_op.getResultTypes()[0]; - - // Get the quantized tensor manipulation op's output type and update. - Value singular_op_result = singular_op.getResult(); - const auto singular_op_result_type = - singular_op_result.getType().cast(); - const ArrayRef singular_op_shape = - singular_op_result_type.getShape(); - const TensorType new_singular_op_result_type = - singular_op_result_type.cloneWith( - singular_op_shape, - getElementTypeOrSelf(operand_type).cast()); - singular_op_result.setType(new_singular_op_result_type); - - // Create requantization op and return. - rewriter.setInsertionPointAfter(singular_op); - CreateAndReturnUniformQuantizeOp(rewriter, *singular_op, entry_func_op, - func_result_type); -} - // Quantizes the entry function's body containing a `DotGeneralOp`. class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeDotGeneralOpPattern( - const bool enable_per_channel_quantized_weight) + const bool enable_per_channel_quantized_weight, + const bool enable_weight_only) : enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight) {} + enable_per_channel_quantized_weight), + enable_weight_only_(enable_weight_only) {} LogicalResult match(func::FuncOp entry_func_op) const override { return MatchGemmStyleOp(entry_func_op); } - void rewrite(func::FuncOp entry_func_op, + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override { + if (enable_weight_only_) return; DotGeneralOp dot_general_op = *entry_func_op.getOps().begin(); const bool should_quantize_per_channel = enable_per_channel_quantized_weight_ && @@ -483,44 +432,92 @@ class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { } private: - const bool enable_per_channel_quantized_weight_; + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; + // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform + // weight-only quantization. + const bool enable_weight_only_; }; // Quantizes the entry function's body containing a `ConvolutionOp`. class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeConvolutionOpPattern( - const bool enable_per_channel_quantized_weight) + const bool enable_per_channel_quantized_weight, + const bool enable_weight_only) : enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight) {} + enable_per_channel_quantized_weight), + enable_weight_only_(enable_weight_only) {} LogicalResult match(func::FuncOp entry_func_op) const override { return MatchGemmStyleOp(entry_func_op); } - void rewrite(func::FuncOp entry_func_op, + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override { - RewriteGemmStyleOp(entry_func_op, rewriter, - enable_per_channel_quantized_weight_); + if (enable_weight_only_) return; + RewriteGemmStyleOp( + entry_func_op, rewriter, + enable_per_channel_quantized_weight_ && + IsWeightPerChannelQuantized(quantization_method)); + } + + // Returns true if the quantization method indicates per-channel quantization + // for convolution weights. This method specifically matches a quantization + // dimension of 3 for the input index 1. + bool IsWeightPerChannelQuantized(const Method& quantization_method) const { + if (quantization_method.has_static_range_ptq()) { + const StaticRangePtq& static_range_ptq_spec = + quantization_method.static_range_ptq(); + + if (static_range_ptq_spec.input_quantized_types().contains(1)) { + const QuantizedType& weight_quantized_type = + static_range_ptq_spec.input_quantized_types().at(1); + return weight_quantized_type.dimension_specs().dimension() == 3; + } + } + return false; } private: - const bool enable_per_channel_quantized_weight_; + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; + // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform + // weight-only quantization. + const bool enable_weight_only_; }; -// Quantizes the entry function's body containing a `GatherOp`. -class QuantizeGatherOpPattern : public EntryFuncBodyQuantizationPattern { +template +class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern { public: - explicit QuantizeGatherOpPattern( - const bool enable_per_channel_quantized_weight) {} + explicit QuantizeSingularOpPattern( + const bool enable_per_channel_quantized_weight, + const bool enable_weight_only) {} LogicalResult match(func::FuncOp entry_func_op) const override { - return MatchSingularOp(entry_func_op); + const auto op_iterator_range = entry_func_op.getOps(); + if (op_iterator_range.empty()) { + LLVM_DEBUG(llvm::dbgs() << "Function does not have " + << SingularOpT::getOperationName() << " op.\n"); + return failure(); + } + if (!isa( + (*op_iterator_range.begin()).getResult().getType())) { + LLVM_DEBUG(llvm::dbgs() << SingularOpT::getOperationName() + << " op must have ranked tensor type.\n"); + return failure(); + } + return success(); } - void rewrite(func::FuncOp entry_func_op, + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override { - RewriteSingularOp(entry_func_op, rewriter); + auto singular_op = *entry_func_op.getOps().begin(); + + Value singular_op_result = singular_op.getResult(); + singular_op_result.setType(entry_func_op.getResultTypes()[0]); } }; @@ -528,14 +525,17 @@ class QuantizeGatherOpPattern : public EntryFuncBodyQuantizationPattern { // inputs and outputs of `xla_call_module_op` that are possibly quantized. It // signature (type) is reset to match that of `xla_call_module_op`. // `entry_func_body_quantization_pattern` rewrites the function's body, based on -// the new signature. +// the new signature. `quantization_method` specifies the quantization method +// applied to the quantizable unit `xla_call_module_op` and its corresponding +// function `entry_func_op`. void QuantizeEntryFuncOp( const MLIRContext& ctx, PatternRewriter& rewriter, const TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op, - const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern, + const Method& quantization_method) { SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op); - body_rewrite_pattern.rewrite(entry_func_op, rewriter); + body_rewrite_pattern.rewrite(entry_func_op, quantization_method, rewriter); // Rename the function to be clear that the function has been quantized. const std::string quantized_function_name = @@ -549,13 +549,14 @@ void QuantizeEntryFuncOp( void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( const MLIRContext& ctx, PatternRewriter& rewriter, TF::XlaCallModuleOp xla_call_module_op, - const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern, + const Method& quantization_method) { const ModuleOp module_op = xla_call_module_op->getParentOfType(); const SymbolTable symbol_table(module_op); func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table); QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, - body_rewrite_pattern); + body_rewrite_pattern, quantization_method); // Replace the XlaCallModuleOp with a new CallOp. rewriter.setInsertionPoint(xla_call_module_op); @@ -581,10 +582,12 @@ template { public: explicit XlaCallModuleOpToCallOp( - MLIRContext& ctx, const bool enable_per_channel_quantized_weight) + MLIRContext& ctx, const bool enable_per_channel_quantized_weight, + const bool enable_weight_only) : OpRewritePattern(&ctx), enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight) {} + enable_per_channel_quantized_weight), + enable_weight_only_(enable_weight_only) {} LogicalResult match(TF::XlaCallModuleOp op) const override { ModuleOp module_op = op->getParentOfType(); @@ -593,24 +596,44 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { // Ignore unquantized ops. if (!IsQuantizedXlaCallModuleOp(op)) return failure(); + // For weight-only quantization, op should be hybrid quantized. + if (enable_weight_only_ && !IsHybridQuantizedOp(op)) { + return failure(); + } + func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); if (!entry_func_op) { op->emitError("Failed to find a valid entry function."); return failure(); } - return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + + return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_, + enable_weight_only_) .match(entry_func_op); } void rewrite(TF::XlaCallModuleOp xla_call_module_op, PatternRewriter& rewriter) const override { + // TODO: b/331145946 - Each quantization method should be valid + // (GetQuantizationMethodOrDefault swallows invalid method attribute). Check + // the validity in `match()`. Use accessors to achieve this. + const Method quantization_method = + GetQuantizationMethodOrDefault(xla_call_module_op); + ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT(enable_per_channel_quantized_weight_)); + FuncBodyRewritePatternT(enable_per_channel_quantized_weight_, + enable_weight_only_), + quantization_method); } private: - const bool enable_per_channel_quantized_weight_; + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; + // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform + // weight-only quantization. + const bool enable_weight_only_; }; // Quantizes op with regions such as stablehlo.reduce_window op. @@ -620,7 +643,7 @@ class QuantizeOpWithRegionPattern : public OpRewritePattern { public: explicit QuantizeOpWithRegionPattern(MLIRContext& ctx) - : OpRewritePattern(&ctx){}; + : OpRewritePattern(&ctx) {}; LogicalResult match(quantfork::DequantizeCastOp op) const final { // Match only when there is one user of the dequantize op. @@ -885,79 +908,50 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { return false; } -class QuantizeHybridDotGeneralPattern - : public EntryFuncBodyQuantizationPattern { +template +class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern { public: - explicit QuantizeHybridDotGeneralPattern() = default; + explicit QuantizeWeightOnlyOpPattern( + const bool enable_per_channel_quantized_weight) {} LogicalResult match(func::FuncOp entry_func_op) const override { - return MatchGemmStyleOp(entry_func_op); + return MatchGemmStyleOp(entry_func_op); } - void rewrite(func::FuncOp entry_func_op, + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override {} }; -template >> -class HybridXlaCallModuleOpToCallOp - : public OpRewritePattern { - public: - explicit HybridXlaCallModuleOpToCallOp( - MLIRContext& ctx, bool enable_per_channel_quantized_weight) - : OpRewritePattern(&ctx){}; - - LogicalResult match(TF::XlaCallModuleOp op) const override { - ModuleOp module_op = op->getParentOfType(); - SymbolTable symbol_table(module_op); - - // Ignore unquantized ops. - if (!IsHybridQuantizedOp(op) || !IsOpQuantizableStableHlo(op)) { - return failure(); - } - - func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); - if (!entry_func_op) { - op->emitError("Failed to find a valid entry function."); - return failure(); - } - return FuncBodyRewritePatternT().match(entry_func_op); - } - - void rewrite(TF::XlaCallModuleOp xla_call_module_op, - PatternRewriter& rewriter) const override { - ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT()); - } -}; - -// TODO: b/307620428 - Increase fused op coverage for static range quantization. -void PopulateFusedGemmStylePatterns( +// Compute heavy patterns should be quantized for both server and ODML targets. +void PopulateComputeHeavyPatterns( MLIRContext& ctx, RewritePatternSet& patterns, const bool enable_per_channel_quantized_weight) { patterns.add>( - ctx, enable_per_channel_quantized_weight); + ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false); patterns.add>( - ctx, enable_per_channel_quantized_weight); + ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false); + // TODO: b/307620772 - Per-channel quantization for gather. + patterns.add>>( + ctx, /*enable_per_channel_quantized_weight=*/false, + /*enable_weight_only=*/false); + // Populate pattern for quantization of ops with regions such as + // `stablehlo.reduce_window` op. + patterns.add(ctx); } -void PopulateQuantizeHybridPatterns(MLIRContext& ctx, +void PopulateAllQuantizablePatterns(MLIRContext& ctx, RewritePatternSet& patterns) { - patterns.add>( - ctx, false); + patterns.add>>( + ctx, /*enable_per_channel_quantized_weight=*/false, + /*enable_weight_only=*/false); } -void PopulateQuantizeOpWithRegionPattern(MLIRContext& ctx, - RewritePatternSet& patterns) { - patterns.add(ctx); -} - -void PopulateQuantizeSingularOpPatterns(MLIRContext& ctx, +void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx, RewritePatternSet& patterns) { - // TODO: b/307620772 - Per-channel quantization for gather. - patterns.add>( - ctx, /*enable_per_channel_quantized_weight=*/false); + patterns.add, + XlaCallModuleOpToCallOp>( + ctx, /*enable_per_channel_quantized_weight*/ false, + /*enable_weight_only=*/true); } + } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h index 7b681cc71f71e3..9aa33ee0316ee1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -60,14 +60,15 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op); // Each matched pattern are rewritten by its quantized alternatives. // // The concrete pattern, extends from this base pattern, can specify whether it -// allows hybrid quantization. If it is allowed, for operand/result that is not -// adjacent to dequantize/quantize op, it remains as float. For operand/result -// that is adjacent to dequantize/quantize, it is quantized. Hybrid quantization -// can be used to generate both weight-only quantization and dynamic range -// quantization. The condition for allowing hybrid quantization or not for an op -// can be specified in the below function: +// allows weight-only quantization. If it is allowed, for operand/result that is +// not adjacent to dequantize/quantize op, it remains as float. For +// operand/result that is adjacent to dequantize/quantize, it is quantized. +// Weight-only quantization can be used to generate both weight-only +// quantization and dynamic range quantization. The condition for allowing +// weight-only quantization or not for an op can be specified in the below +// function: // -// static bool AllowHybridQuantization(Operation& op) +// static bool AllowWeightOnlyQuantization(Operation& op) // // This is a templatized `OpRewritePattern`. // @@ -177,8 +178,8 @@ class StableHloQuantizationPattern : public OpRewritePattern { // If the operand is an integer tensor, then it doesn't require the // DequantizeOp in the pattern. inputs.push_back(operand); - } else if (static_cast(this)->AllowHybridQuantization( - *candidate_op)) { + } else if (static_cast(this) + ->AllowWeightOnlyQuantization(*candidate_op)) { inputs.push_back(operand); } else { return failure(); @@ -214,8 +215,8 @@ class StableHloQuantizationPattern : public OpRewritePattern { // D op in the pattern. outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); - } else if (static_cast(this)->AllowHybridQuantization( - *candidate_op)) { + } else if (static_cast(this) + ->AllowWeightOnlyQuantization(*candidate_op)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); } else { @@ -249,22 +250,17 @@ class StableHloQuantizationPattern : public OpRewritePattern { } }; -// Gemm Style Op: glossary/gemm. -void PopulateFusedGemmStylePatterns(MLIRContext& ctx, - RewritePatternSet& patterns, - bool enable_per_channel_quantized_weight); +// Populates pattern for compute heavy operations. +void PopulateComputeHeavyPatterns(MLIRContext& ctx, RewritePatternSet& patterns, + bool enable_per_channel_quantized_weight); -// Populates pattern for hybrid quantization. -void PopulateQuantizeHybridPatterns(MLIRContext& ctx, +// Populates conversion patterns for all quantizable ops, including +// ops that are not compute-heavy and data movement ops. +void PopulateAllQuantizablePatterns(MLIRContext& ctx, RewritePatternSet& patterns); -// Populates pattern for quantization of ops with regions such as -// stablehlo.reduce_window op. -void PopulateQuantizeOpWithRegionPattern(MLIRContext& ctx, - RewritePatternSet& patterns); - -// Populates conversion patterns for unary data movement ops. -void PopulateQuantizeSingularOpPatterns(MLIRContext& ctx, +// Populates pattern weight-only quantization. +void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx, RewritePatternSet& patterns); } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 048f0f04cff789..8bb2bd33564481 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -57,7 +57,7 @@ struct StableHloQuantizationBase quantfork::DequantizeCastOp, /*VerifierT=*/void, RootOpT>(ctx) {} - static bool AllowHybridQuantization(Operation& op) { return false; } + static bool AllowWeightOnlyQuantization(Operation& op) { return false; } }; // Quantization rewrite pattern using DQ as the root op. @@ -77,15 +77,22 @@ struct StableHloQuantizationReverse quantfork::QuantizeCastOp>(ctx) {} }; +bool IsHybridQuantizableOp(Operation& op) { + auto call_op = cast(op); + if (call_op == nullptr) return false; + StringRef entry_function_name = GetEntryFunctionName(call_op); + return entry_function_name.contains("conv") || + entry_function_name.contains("dot_general"); +} + // Quantization rewrite pattern using DQ as the root op. -struct StableHloQuantizationHybrid - : public StableHloQuantizationBase { - explicit StableHloQuantizationHybrid(MLIRContext* ctx) - : StableHloQuantizationBase(ctx) {} - - static bool AllowHybridQuantization(Operation& op) { - auto call_op = cast(op); - return call_op && GetEntryFunctionName(call_op).contains("dot_general"); +struct StableHloQuantizationWeightOnly + : public StableHloQuantizationBase { + explicit StableHloQuantizationWeightOnly(MLIRContext* ctx) + : StableHloQuantizationBase(ctx) {} + + static bool AllowWeightOnlyQuantization(Operation& op) { + return IsHybridQuantizableOp(op); } }; @@ -96,9 +103,10 @@ class QuantizePass : public impl::QuantizePassBase { using impl::QuantizePassBase::QuantizePassBase; explicit QuantizePass(const bool enable_per_channel_quantized_weight, - const bool enable_weight_only, - const QuantizationSpecs& quant_specs) { + const bool enable_full_int_quantization, + const bool enable_weight_only) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + enable_full_int_quantization_ = enable_full_int_quantization; enable_weight_only_ = enable_weight_only; } @@ -113,14 +121,17 @@ void QuantizePass::runOnOperation() { RewritePatternSet patterns(&ctx); patterns.add(&ctx); if (enable_weight_only_) { - patterns.add(&ctx); - PopulateQuantizeHybridPatterns(ctx, patterns); + patterns.add(&ctx); + PopulateQuantizeWeightOnlyPatterns(ctx, patterns); } - PopulateQuantizeOpWithRegionPattern(ctx, patterns); - PopulateFusedGemmStylePatterns(ctx, patterns, - enable_per_channel_quantized_weight_); - PopulateQuantizeSingularOpPatterns(ctx, patterns); + PopulateComputeHeavyPatterns(ctx, patterns, + enable_per_channel_quantized_weight_); + + // Quantize all quantizable ops, including ops that are not compute-heavy. + if (enable_full_int_quantization_) { + PopulateAllQuantizablePatterns(ctx, patterns); + } if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { // There are cases where no rewrites happen even if a pattern matches, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index 9478cea46c8795..f3cf92dde359d1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -55,8 +55,9 @@ class QuantizeCompositeFunctionsPass explicit QuantizeCompositeFunctionsPass( const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) { + const bool enable_weight_only, const bool enable_full_int_quantization) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + enable_full_int_quantization_ = enable_full_int_quantization; enable_weight_only_ = enable_weight_only; } @@ -80,7 +81,7 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { options.bit_width_ = 8; if (enable_weight_only_) { - pm.addNestedPass(createPrepareQuantizeHybridPass()); + pm.addNestedPass(createInsertWeightParamPass()); } // PrepareQuantizePass uses SymbolTable to fetch relevant GEMM ops for // determining quantization attributes. This requires module-level context. @@ -89,6 +90,8 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { QuantizePassOptions quantize_options; quantize_options.enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight_; + quantize_options.enable_full_int_quantization_ = + enable_full_int_quantization_; quantize_options.enable_weight_only_ = enable_weight_only_; // QuantizePass modifies FuncOps referenced outside of its given scope // and therefore requires a module-level context. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h index aa9c2106789f27..a8a59d1cd3b46b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h @@ -19,11 +19,22 @@ limitations under the License. namespace mlir::quant::stablehlo::testing { +// Identifies predefined `QuantizationSpecs` for +// `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`. The pass +// option argument is specified in line comments for each enum value. +enum class TestQuantizationSpecs { + kEmpty, // empty + kDisableAllDotGeneral, // disable-all-dot-general + kStaticRangePtqToAll, // static-range-ptq-to-all + kStaticRangePtqToComputeHeavy, // static-range-ptq-to-compute-heavy +}; + // Adds generated pass default constructors or options definitions. #define GEN_PASS_DECL // Adds generated pass registration functions. #define GEN_PASS_REGISTRATION #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h.inc" + } // namespace mlir::quant::stablehlo::testing #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TESTING_PASSES_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td index 38d60e94f97e9a..ee525f2deead04 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td @@ -69,6 +69,22 @@ def TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass : This test-only pass is the same as `LiftQuantizableSpotsAsFunctionsPass` but has predefined `QuantizationSpecs` to make FileCheck testing easier. }]; + let options = [ + Option<"quantization_specs_", "quantization-specs", + "mlir::quant::stablehlo::testing::TestQuantizationSpecs", + /*default=*/"mlir::quant::stablehlo::testing::TestQuantizationSpecs::kEmpty", + "Sets one of the predefined `QuantizationSpecs` for testing.", + [{llvm::cl::values( + clEnumValN(mlir::quant::stablehlo::testing::TestQuantizationSpecs::kEmpty, + "empty", "Uses empty (default) QuantizationSpecs."), + clEnumValN(mlir::quant::stablehlo::testing::TestQuantizationSpecs::kDisableAllDotGeneral, + "disable-all-dot-general", "Disables all dot_general ops by matching lifted function names"), + clEnumValN(mlir::quant::stablehlo::testing::TestQuantizationSpecs::kStaticRangePtqToAll, + "static-range-ptq-to-all", "Applies `StaticRangePtq` to all quantizable units."), + clEnumValN(mlir::quant::stablehlo::testing::TestQuantizationSpecs::kStaticRangePtqToComputeHeavy, + "static-range-ptq-to-compute-heavy", "Applies `StaticRangePtq` to only compute heavy units.") + )}]> + ]; let dependentDialects = [ "mlir::func::FuncDialect", "mlir::stablehlo::StablehloDialect", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc index e8cb185cb7b55d..062fbdddd4150d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep #include "tsl/platform/protobuf.h" // IWYU pragma: keep @@ -39,15 +40,37 @@ using ::tsl::protobuf::TextFormat; // NOLINTNEXTLINE(misc-include-cleaner) - Required for OSS. using ::tsl::protobuf::io::ArrayInputStream; +// Empty (default) `QuantizationSpecs` proto. +constexpr absl::string_view kSpecsEmpty = R"pb(specs + [])pb"; + // Configure `QuantizationSpecs` to disable quantization for all dot_general // quantizable units. -constexpr absl::string_view kSpecsDisableAllDotGeneralByFuncName = +constexpr absl::string_view kSpecsDisableAllDotGeneral = R"pb(specs [ { matcher { function_name { regex: "composite_dot_general_.*" } } method { no_quantization {} } }])pb"; +// Configure `QuantizationSpecs` to apply `StaticRangePtq` to all quantizable +// units. +constexpr absl::string_view kSpecsStaticRangePtqToAll = + R"pb(specs + [ { + matcher { function_name { regex: ".*" } } + method { static_range_ptq {} } + }])pb"; + +// Configure `QuantizationSpecs` to apply `StaticRangePtq` to compute heavy +// units. +constexpr absl::string_view kSpecsStaticRangePtqToComputeHeavy = + R"pb(specs + [ { + matcher { function_name { regex: "^.*(conv|dot|gather).*" } } + method { static_range_ptq {} } + }])pb"; + class TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass : public impl:: TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase< @@ -64,9 +87,24 @@ class TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass void runOnOperation() override; }; +// `TestQuantizationSpecs` -> predefined `QuantizationSpecs` textproto. +absl::string_view GetQuantizationSpecsTextProto( + const TestQuantizationSpecs test_specs) { + switch (test_specs) { + case TestQuantizationSpecs::kEmpty: + return kSpecsEmpty; + case TestQuantizationSpecs::kDisableAllDotGeneral: + return kSpecsDisableAllDotGeneral; + case TestQuantizationSpecs::kStaticRangePtqToAll: + return kSpecsStaticRangePtqToAll; + case TestQuantizationSpecs::kStaticRangePtqToComputeHeavy: + return kSpecsStaticRangePtqToComputeHeavy; + } +} + // Parses a text proto into a `QuantizationSpecs` proto. Returns // `InvalidArgumentError` if `text_proto` is invalid. -absl::StatusOr ParseQuantizationSpecsTextProto( +absl::StatusOr ParseTextProto( const absl::string_view text_proto) { QuantizationSpecs quantization_specs; TextFormat::Parser parser; @@ -81,8 +119,9 @@ void TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass:: runOnOperation() { PassManager pass_manager{&getContext()}; + // Construct `QuantizationSpecs` from the pass option `quantization-specs`. const absl::StatusOr quantization_specs = - ParseQuantizationSpecsTextProto(kSpecsDisableAllDotGeneralByFuncName); + ParseTextProto(GetQuantizationSpecsTextProto(quantization_specs_)); if (!quantization_specs.ok()) { signalPassFailure(); return; @@ -93,7 +132,6 @@ void TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass:: if (failed(pass_manager.run(getOperation()))) { signalPassFailure(); - return; } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_pre_calibration_component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_pre_calibration_component.cc index 06b53035c80c7a..0c41771a5c43b0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_pre_calibration_component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_pre_calibration_component.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" @@ -34,6 +35,8 @@ namespace mlir::quant::stablehlo::testing { namespace { +using ::stablehlo::quantization::ExpandPresets; +using ::stablehlo::quantization::PopulateDefaults; using ::stablehlo::quantization::QuantizationConfig; class TestPreCalibrationComponentPass @@ -52,7 +55,10 @@ void TestPreCalibrationComponentPass::runOnOperation() { // Simply runs the PreCalibrationComponent with a default configuration. PreCalibrationComponent component(&ctx); - if (!component.Run(module_op, QuantizationConfig::default_instance()).ok()) { + QuantizationConfig quantization_config{}; + quantization_config.mutable_static_range_ptq_preset(); + quantization_config = ExpandPresets(PopulateDefaults(quantization_config)); + if (!component.Run(module_op, quantization_config).ok()) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index a9bd3a713ede7c..2b20cc48a89d69 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -133,6 +133,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:weight_only_ptq", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 80ccf81c33b9b9..80a2c560ef865b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -64,7 +64,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): ([10, 1, 1024], [10, 1024, 3]), ([2, 3, 1, 1024], [2, 3, 1024, 3]), ), - 'rng_seed': (1230, 1231, 1232, 1233), }]) ) @test_util.run_in_graph_and_eager_modes @@ -73,7 +72,6 @@ def test_matmul_ptq_model( bias_fn: Optional[ops.Operation], activation_fn: Optional[ops.Operation], dim_sizes: Sequence[int], - rng_seed: int, ): lhs_dim_size, rhs_dim_size = dim_sizes input_shape = (*lhs_dim_size,) @@ -87,7 +85,7 @@ def test_matmul_ptq_model( activation_fn, ) - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -144,6 +142,14 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # values are arbitrary. self.assertAllClose(new_outputs, expected_outputs, rtol=0.03, atol=0.2) + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.65, + ) + @parameterized.parameters( testing.parameter_combinations([{ 'same_scale_op': ( @@ -156,14 +162,12 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 'slice', 'transpose', ), - 'rng_seed': (0, 11, 222, 3333), }]) ) @test_util.run_in_graph_and_eager_modes def test_matmul_and_same_scale_ptq_model( self, same_scale_op: str, - rng_seed: int, ): input_shape = (2, 3, 1, 1024) filter_shape = (2, 3, 1024, 3) @@ -176,7 +180,7 @@ def test_matmul_and_same_scale_ptq_model( same_scale_op, ) - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -225,6 +229,14 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # values are arbitrary. self.assertAllClose(new_outputs, expected_outputs, rtol=0.03, atol=0.2) + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.65, + ) + @parameterized.parameters( testing.parameter_combinations([{ 'same_scale_op': ( @@ -233,7 +245,6 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # TODO: b/326242075 - Support other same-scale ops. ), 'dim_sizes': (([None, 1024], [1024, 3]),), - 'rng_seed': (0, 11, 222, 3333), }]) ) @test_util.run_in_graph_and_eager_modes @@ -241,7 +252,6 @@ def test_matmul_and_same_scale_ptq_model_dynamic( self, same_scale_op: str, dim_sizes: Sequence[int], - rng_seed: int, ): input_dim_size, filter_dim_size = dim_sizes input_shape = (*input_dim_size,) @@ -255,7 +265,7 @@ def test_matmul_and_same_scale_ptq_model_dynamic( same_scale_op, ) - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -304,6 +314,14 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # values are arbitrary. self.assertAllClose(new_outputs, expected_outputs, rtol=0.03, atol=0.2) + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.6, + ) + @parameterized.parameters( testing.parameter_combinations([{ 'bias_fn': ( @@ -315,7 +333,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: nn_ops.relu, nn_ops.relu6, ), - 'has_batch_norm': (False,), + 'has_batch_norm': (False, True), 'input_shape_dynamic': ( False, True, @@ -324,7 +342,6 @@ def data_gen() -> repr_dataset.RepresentativeDataset: False, True, ), - 'rng_seed': (10, 11, 12, 13), }]) ) @test_util.run_in_graph_and_eager_modes @@ -335,7 +352,6 @@ def test_conv_ptq_model( has_batch_norm: bool, input_shape_dynamic: bool, enable_per_channel_quantized_weight: bool, - rng_seed: int, dilations: Sequence[int] = None, ): input_shape = (None, 3, 4, 3) if input_shape_dynamic else (1, 3, 4, 3) @@ -351,9 +367,18 @@ def test_conv_ptq_model( strides, dilations, ) + # TODO(b/331809306): investigate why these tests fail. + # skip these test cases. + if ( + bias_fn is None + and has_batch_norm + and input_shape_dynamic + and enable_per_channel_quantized_weight + ): + return # Generate model input data. - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) static_input_shape = [dim if dim is not None else 2 for dim in input_shape] input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( @@ -412,19 +437,25 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # values are arbitrary. self.assertAllClose(new_outputs, expected_outputs, rtol=0.02, atol=0.05) + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.61, + ) + @parameterized.parameters( testing.parameter_combinations([{ 'equation': ( 'abc,cde->abde', 'abc,dce->abde', ), - 'rng_seed': (82, 82732, 4444, 14), }]) ) def test_einsum_ptq_model( self, equation: str, - rng_seed: int, ): _, y_shape, bias_shape, x_signature, y_signature = ( self._prepare_sample_einsum_datashapes(equation, use_bias=True) @@ -440,7 +471,7 @@ def test_einsum_ptq_model( ) # Generate model input data. - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=x_signature).astype('f4') ) @@ -489,6 +520,14 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # values are arbitrary. self.assertAllClose(new_outputs, expected_outputs, rtol=0.02, atol=0.04) + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.65, + ) + def test_when_preset_not_srq_raises_error(self): self._create_matmul_model( input_shape=(1, 1024), @@ -573,6 +612,14 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # are negligible numeric difference. self.assertAllClose(new_outputs, expected_outputs, rtol=0.000001) + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.4, + ) + @test_util.run_in_graph_and_eager_modes def test_ptq_selective_denylist(self): """Tests that the op is not quantized when no quantization is enabled.""" @@ -667,6 +714,14 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # Indirectly tests that the model is only partially quantized. self.assertAllClose(new_outputs, expected_outputs, rtol=0.011) + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.55, + ) + @test_util.run_in_graph_and_eager_modes def test_ptq_quantization_method_not_applied_when_matcher_mismatch(self): """Tests that quantization method is not applied to unmatched units.""" @@ -737,6 +792,14 @@ def data_gen() -> repr_dataset.RepresentativeDataset: self.assertAllClose(new_outputs, expected_outputs, rtol=0.04) self.assertNotAllClose(new_outputs, expected_outputs, rtol=0.00001) + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.4, + ) + @test_util.run_all_in_graph_and_eager_modes class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): @@ -746,47 +809,49 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): (default in TF2) to ensure support for when TF2 is disabled. """ - # TODO(b/307621353): add CALIBRATION_METHOD_HISTOGRAM_PERCENTILE. @parameterized.parameters( { - 'calibration_options': - qc.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX # pylint: disable=line-too-long - ) + 'calibration_options': qc.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX + ) }, { - 'calibration_options': - qc.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX # pylint: disable=line-too-long + 'calibration_options': qc.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX + ), + }, + { + 'calibration_options': qc.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, + calibration_parameters=qc.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, ), + ), }, { - 'calibration_options': - qc.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, # pylint: disable=line-too-long - calibration_parameters=qc.CalibrationOptions.CalibrationParameters( # pylint: disable=line-too-long - initial_num_bins=10, - ), + 'calibration_options': qc.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, + calibration_parameters=qc.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, ), + ), }, { - 'calibration_options': - qc.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, # pylint: disable=line-too-long - calibration_parameters=qc.CalibrationOptions.CalibrationParameters( # pylint: disable=line-too-long - initial_num_bins=10, - ), + 'calibration_options': qc.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, + calibration_parameters=qc.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, ), + ), }, { - 'calibration_options': - qc.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, # pylint: disable=line-too-long - calibration_parameters=qc.CalibrationOptions.CalibrationParameters( # pylint: disable=line-too-long - initial_num_bins=10, - ), + 'calibration_options': qc.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, + calibration_parameters=qc.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, ), - } + ), + }, ) @test_util.run_in_graph_and_eager_modes def test_conv_ptq_model_by_calibration_options( @@ -796,7 +861,7 @@ def test_conv_ptq_model_by_calibration_options( bias_fn = nn_ops.bias_add activation_fn = nn_ops.relu6 enable_per_channel_quantized_weight = False - has_batch_norm = False + has_batch_norm = True dilations = None input_shape = (1, 3, 4, 3) filter_shape = (2, 3, 3, 2) @@ -814,18 +879,14 @@ def test_conv_ptq_model_by_calibration_options( # Generate model input data. input_data = ops.convert_to_tensor( - np.random.uniform(low=0.0, high=10, size=input_shape).astype( - 'f4' - ) + np.random.uniform(low=0.0, high=10, size=input_shape).astype('f4') ) def data_gen() -> repr_dataset.RepresentativeDataset: for _ in range(100): yield { 'input_tensor': ops.convert_to_tensor( - np.random.uniform(low=0, high=10, size=input_shape).astype( - 'f4' - ) + np.random.uniform(low=0, high=10, size=input_shape).astype('f4') ), } @@ -865,6 +926,199 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # values are arbitrary. self.assertAllClose(new_outputs, expected_outputs, rtol=0.02, atol=0.5) + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.46, + ) + + +class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): + + @parameterized.parameters( + testing.parameter_combinations([{ + 'bias_fn': ( + None, + nn_ops.bias_add, + ), + 'activation_fn': ( + None, + nn_ops.relu, + nn_ops.relu6, + ), + 'dim_sizes': ( + # tf.MatMul cases. + ([None, 1024], [1024, 3]), # dynamic batch dim. + ([1, 1024], [1024, 3]), + # tf.BatchMatMul cases. + ([10, 1, 1024], [10, 1024, 3]), + ([2, 3, 1, 1024], [2, 3, 1024, 3]), + ), + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_matmul_weight_only_model( + self, + bias_fn: Optional[ops.Operation], + activation_fn: Optional[ops.Operation], + dim_sizes: Sequence[int], + ): + lhs_dim_size, rhs_dim_size = dim_sizes + input_shape = (*lhs_dim_size,) + filter_shape = (*rhs_dim_size,) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + model = self._create_matmul_model( + input_shape, + filter_shape, + self._input_saved_model_path, + bias_fn, + activation_fn, + ) + + rng = np.random.default_rng(1234) + input_data = ops.convert_to_tensor( + rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( + np.float32 + ) + ) + + config = qc.QuantizationConfig( + weight_only_preset=qc.WeightOnlyPreset(), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + expected_outputs = model.matmul(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + # Tests that the quantized graph outputs similar values. The rtol and atol + # values are arbitrary. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.03, atol=0.2) + + module_str = self._extract_first_xla_call_module_op( + self._output_saved_model_path + ) + + # Tests that the output graph contains subtract and multiply for + # dequantization. + self.assertTrue(re.search('stablehlo.subtract', module_str)) + self.assertTrue(re.search('stablehlo.multiply', module_str)) + # Tests that the output graph contains float dot_general. + self.assertTrue( + re.search('stablehlo.dot_general.*xf32>.*xf32>.*xf32>', module_str) + ) + + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.3, + ) + + @parameterized.parameters( + testing.parameter_combinations([{ + 'bias_fn': ( + None, + nn_ops.bias_add, + ), + 'activation_fn': ( + None, + nn_ops.relu, + nn_ops.relu6, + ), + 'has_batch_norm': (False,), + 'input_shape_dynamic': ( + False, + True, + ), + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_conv_weight_only_model( + self, + bias_fn: Optional[ops.Operation], + activation_fn: Optional[ops.Operation], + has_batch_norm: bool, + input_shape_dynamic: bool, + dilations: Sequence[int] = None, + ): + input_shape = (None, 3, 4, 3) if input_shape_dynamic else (1, 3, 4, 3) + filter_shape = (2, 3, 3, 2) + strides = (1, 1, 1, 1) + model = self._create_conv2d_model( + input_shape, + filter_shape, + self._input_saved_model_path, + bias_fn, + activation_fn, + has_batch_norm, + strides, + dilations, + ) + + rng = np.random.default_rng(1234) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + input_data = ops.convert_to_tensor( + rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( + np.float32 + ) + ) + + config = qc.QuantizationConfig( + weight_only_preset=qc.WeightOnlyPreset(), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + expected_outputs = model.conv2d(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + # Tests that the quantized graph outputs similar values. The rtol and atol + # values are arbitrary. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.03, atol=0.2) + + module_str = self._extract_first_xla_call_module_op( + self._output_saved_model_path + ) + + # Tests that the output graph contains subtract and multiply for + # dequantization. + self.assertTrue(re.search('stablehlo.subtract', module_str)) + self.assertTrue(re.search('stablehlo.multiply', module_str)) + # Tests that the output graph contains float dot_general. + self.assertTrue( + re.search('stablehlo.convolution.*xf32>.*xf32>.*xf32>', module_str) + ) + + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.35, + ) + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py index 8a5f2529c56e22..d71c89e15d313f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py @@ -284,13 +284,13 @@ def conv2d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: ) if bias_fn is not None: out = nn_ops.bias_add(out, self.bias) - if activation_fn is not None: - out = activation_fn(out) if has_batch_norm: # Fusing is supported for non-training case. out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( out, scale, offset, mean, variance, is_training=False ) + if activation_fn is not None: + out = activation_fn(out) return {'output': out} model = ConvModel() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc index 6ee6f9ac317ce0..3269006ec06dbb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc @@ -27,8 +27,10 @@ namespace py = pybind11; namespace { +using ::stablehlo::quantization::pywrap::PywrapExpandPresets; using ::stablehlo::quantization::pywrap::PywrapPopulateDefaults; using ::stablehlo::quantization::pywrap::PywrapQuantizeStaticRangePtq; +using ::stablehlo::quantization::pywrap::PywrapQuantizeWeightOnlyPtq; } // namespace @@ -60,6 +62,27 @@ PYBIND11_MODULE(pywrap_quantization, m) { py::arg("py_function_library")); // LINT.ThenChange(pywrap_quantization.pyi:static_range_ptq) + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. + // LINT.IfChange(weight_only_ptq) + m.def("weight_only_ptq", &PywrapQuantizeWeightOnlyPtq, + R"pbdoc( + Runs weight-only Quantization on a SavedModel at `src_saved_model_path` + and saves the resulting model to `dst_saved_model_path`. + + The user should pass a serialized `QuantizationConfig` for the + `quantization_config_serialized` argument, and a signature key -> + serialized `SignatureDef` mapping for the `signature_def_map_serialized` + argument. + + Raises `StatusNotOk` exception if when the run was unsuccessful. + )pbdoc", + py::arg("src_saved_model_path"), py::arg("dst_saved_model_path"), + py::arg("quantization_config_serialized"), py::kw_only(), + py::arg("signature_keys"), py::arg("signature_def_map_serialized"), + py::arg("py_function_library")); + // LINT.ThenChange(pywrap_quantization.pyi:weight_only_ptq) + // If the function signature changes, likely its corresponding .pyi type // hinting should also change. // LINT.IfChange(populate_default_configs) @@ -71,5 +94,19 @@ PYBIND11_MODULE(pywrap_quantization, m) { default values to fields that the user did not explicitly specify. )pbdoc", py::arg("user_provided_config_serialized")); - // LINT.ThenChange(pywrap_quantization.pyi:static_range_ptq) + // LINT.ThenChange(pywrap_quantization.pyi:populate_default_configs) + + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. + // LINT.IfChange(expand_preset_configs) + m.def("expand_preset_configs", &PywrapExpandPresets, R"pbdoc( + Expands presets to other fields in `QuantizationConfig`. + + Each preset is expressible by other fields in `QuantizationConfig`. + Returns a copy of `QuantizationConfig` (serialized) where the fields are + expanded from presets. If no preset has been set, it is a no-op and + returns the same copy of the input. + )pbdoc", + py::arg("quantization_config_serialized")); + // LINT.ThenChange(pywrap_quantization.pyi:expand_preset_configs) } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi index f46f44b218ee84..e79e2db2c2ac8f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi @@ -17,7 +17,6 @@ from typing import Any from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd - # LINT.IfChange(static_range_ptq) def static_range_ptq( src_saved_model_path: str, @@ -31,6 +30,18 @@ def static_range_ptq( # LINT.ThenChange() +# LINT.IfChange(weight_only_ptq) +def weight_only_ptq( + src_saved_model_path: str, + dst_saved_model_path: str, + quantization_config_serialized: bytes, + *, + signature_keys: list[str], + signature_def_map_serialized: dict[str, bytes], + py_function_library: py_function_lib.PyFunctionLibrary, +) -> Any: ... # Status + +# LINT.ThenChange() # LINT.IfChange(populate_default_configs) def populate_default_configs( @@ -38,3 +49,10 @@ def populate_default_configs( ) -> bytes: ... # QuantizationConfig # LINT.ThenChange() + +# LINT.IfChange(expand_preset_configs) +def expand_preset_configs( + quantization_config_serialized: bytes, +) -> bytes: ... # QuantizationConfig + +# LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.cc index 4fe33c60147df7..3b5ece120bdeb0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.cc @@ -22,12 +22,14 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" namespace stablehlo::quantization::pywrap { using ::mlir::quant::stablehlo::QuantizeStaticRangePtq; +using ::mlir::quant::stablehlo::QuantizeWeightOnlyPtq; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::PyFunctionLibrary; @@ -46,9 +48,24 @@ absl::Status PywrapQuantizeStaticRangePtq( py_function_library); } +absl::Status PywrapQuantizeWeightOnlyPtq( + absl::string_view src_saved_model_path, + absl::string_view dst_saved_model_path, const QuantizationConfig& config, + const std::vector& signature_keys, + const absl::flat_hash_map& signature_def_map, + const PyFunctionLibrary& py_function_library) { + return QuantizeWeightOnlyPtq(src_saved_model_path, dst_saved_model_path, + config, signature_keys, signature_def_map, + py_function_library); +} + QuantizationConfig PywrapPopulateDefaults( const QuantizationConfig& user_provided_config) { return PopulateDefaults(user_provided_config); } +QuantizationConfig PywrapExpandPresets(const QuantizationConfig& config) { + return ExpandPresets(config); +} + } // namespace stablehlo::quantization::pywrap diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h index 0f1af29424e79d..ff724abaac5dee 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h @@ -40,11 +40,25 @@ absl::Status PywrapQuantizeStaticRangePtq( signature_def_map, const tensorflow::quantization::PyFunctionLibrary& py_function_library); +// Function used by the pywrap_quantization module to mirror +// `::mlir::quant::stablehlo::QuantizeWeightOnlyPtq`. +absl::Status PywrapQuantizeWeightOnlyPtq( + absl::string_view src_saved_model_path, + absl::string_view dst_saved_model_path, const QuantizationConfig& config, + const std::vector& signature_keys, + const absl::flat_hash_map& + signature_def_map, + const tensorflow::quantization::PyFunctionLibrary& py_function_library); + // Function used by the pywrap_quantization module to mirror // `::stablehlo::quantization::PopulateDefaults`. QuantizationConfig PywrapPopulateDefaults( const QuantizationConfig& user_provided_config); +// Function used by the pywrap_quantization module to mirror +// `::stablehlo::quantization::ExpandPresets`. +QuantizationConfig PywrapExpandPresets(const QuantizationConfig& config); + } // namespace stablehlo::quantization::pywrap #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PYTHON_PYWRAP_QUANTIZATION_LIB_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py index 6938000deaae0e..aa3745a3fdd453 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py @@ -63,15 +63,24 @@ def quantize_saved_model( if not ( config.HasField('static_range_ptq_preset') and len(config.static_range_ptq_preset.representative_datasets) == 1 - ): + ) and not config.HasField('weight_only_preset'): raise ValueError( '`quantize_saved_model` currently only supports static-range PTQ with a' - ' single signature.' + ' single signature or weight-only quantization.' ) + # Updates user-provided `QuantizationConfig`s for the internal quantization + # pipeline to work with. + print('=== User-provided QuantizationConfig ===') + print(config) config = qc.QuantizationConfig.FromString( pywrap_quantization.populate_default_configs(config.SerializeToString()) ) + config = qc.QuantizationConfig.FromString( + pywrap_quantization.expand_preset_configs(config.SerializeToString()) + ) + print('=== Updated QuantizationConfig ===') + print(config) signature_def_map = save_model.get_signatures_from_saved_model( src_saved_model_path, @@ -80,11 +89,21 @@ def quantize_saved_model( ) signature_def_map_serialized = _serialize_signature_def_map(signature_def_map) - pywrap_quantization.static_range_ptq( - src_saved_model_path, - dst_saved_model_path, - quantization_config_serialized=config.SerializeToString(), - signature_keys=list(signature_def_map.keys()), - signature_def_map_serialized=signature_def_map_serialized, - py_function_library=py_function_lib.PyFunctionLibrary(), - ) + if config.HasField('static_range_ptq_preset'): + pywrap_quantization.static_range_ptq( + src_saved_model_path, + dst_saved_model_path, + quantization_config_serialized=config.SerializeToString(), + signature_keys=list(signature_def_map.keys()), + signature_def_map_serialized=signature_def_map_serialized, + py_function_library=py_function_lib.PyFunctionLibrary(), + ) + elif config.HasField('weight_only_preset'): + pywrap_quantization.weight_only_ptq( + src_saved_model_path, + dst_saved_model_path, + quantization_config_serialized=config.SerializeToString(), + signature_keys=list(signature_def_map.keys()), + signature_def_map_serialized=signature_def_map_serialized, + py_function_library=py_function_lib.PyFunctionLibrary(), + ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto index 81aff6e46d5850..efdceebd6c2008 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -28,23 +28,58 @@ message RepresentativeDatasetConfig { } // Preset config for static-range post-training quantization (PTQ). +// // Minimal user input about representative datasets is required. Representative // datasets are required for static-range PTQ to retrieve quantization // statistics via calibration. -// Next ID: 3 +// +// This preset is equivalent to the following `QuantizationSpecs`: +// +// ``` +// specs {matcher {function_name {regex: ".*"}} method {static_range_ptq {}}} +// specs { +// matcher {function_name {regex: "composite_conv.*"}} +// method {static_range_ptq { +// input_quantized_types { +// key: 1 +// value {dimension_specs {dimension: 3}}} +// }} +// } +// ``` +// +// This preset: +// * Applies per-channel quantization for weights (input index 1) of +// convolution quantizable unit family. The quantization dimension is 3, the +// channel dimension, which assumes the weight tensor is in NHWC format. +// * Applies static-range PTQ for all other ops. +// +// Next ID: 4 message StaticRangePtqPreset { // Configures representative dataset. Each item corresponds to a // representative dataset used to calibrate a function. + // If `QuantizationConfig.calibration_options.representative_datasets` is also + // provided then this field will be ignored. repeated RepresentativeDatasetConfig representative_datasets = 1; // NOTE: This field will be deprecated. - // Granularity should be controlled in custom configuration, deprecating - // this field once available. - // If set true, enable channel-wise quantization for all supported ops. - // This value is true by default. - bool enable_per_channel_quantized_weight = 2; + // Granularity should be controlled using `Method`, deprecating this field + // once available. + // + // If set to true, enable channel-wise quantization for: + // * Convolution ops: When the attached `Method` also specifies per-channel + // quantization. + // * Non-convolution ops: All + // + // Default value: true + bool enable_per_channel_quantized_weight = 2 [deprecated = true]; + + // Whether to quantize all quantizable ops or only compute-heavy ops. + bool enable_full_int_quantization = 3; } +// Applies int8 per-tensor weight-only quantization for all dot_general op. +message WeightOnlyPreset {} + // Metadata specific to the input TensorFlow SavedModel, which may be required // to identify the specific MetaGraphDef to quantize, for example. // Next ID: 2 @@ -63,10 +98,66 @@ message PipelineConfig { optional bool unpack_quantized_types = 1; } +// Represents a single quantizable unit, a (nearly) minimum unit of work when +// applying quantization. It may correspond to a single or multiple ops. +// Next ID: 2 +message QuantizableUnit { + // Name of the `FuncOp` symbol corresponding to the "lifted function", + // representing a single quantizable unit. This value is guaranteed to be + // unique across a single `ModuleOp`. + string name = 1; +} + +// Represents a quantization result of a single `QuantizableUnit`. It is +// essentially a `(QuantizableUnit, Method)` pair, where the `Method` +// corresponds to the quantization method eventually applied to the +// `QuantizableUnit`. +// Next ID: 3 +message QuantizationResult { + QuantizableUnit quantizable_unit = 1; + Method method = 2; +} + +// A series of `QuantizationResult`s. See `QuantizationResult` for details. +// Next ID: 2 +message QuantizationResults { + repeated QuantizationResult results = 1; +} + +message QuantizedDimension { + int32 dimension = 1; // Should be less than the rank of the quantized tensor. +} + +// Corresponds to StableHLO's `QuantizedTensorElementType`. Type parameters such +// as `QuantizationParameters` is omitted because they are determined during +// quantization. +// See https://github.com/openxla/stablehlo/blob/main/docs/spec.md#types for +// details. +// +// Currently only supports specifying quantization granularity (e.g. for +// per-channel quantization). +// TODO: b/331144430 - Support specifying storage types. +message QuantizedType { + // Specifies the granularity of quantization parameters for each dimension of + // a quantized tensor. If specified, per-channel quantization is applied. If + // not specified, per-tensor quantization is applied. + // TODO: Make it a `repeated` field to be able to express multi-channel / + // sub-channel quantization. + QuantizedDimension dimension_specs = 1; +} + // A quantization method representing "do not quantize". Mostly used for // denylisting quantizable units from quantization. message NoQuantization {} +// Configurations for static-range post-training quantization method on a +// quantizable unit. +message StaticRangePtq { + // Operand index -> QuantizedType mapping. Operands that are not specified + // here will be quantized with best effort. + map input_quantized_types = 1; +} + // Represents a matching method that matches quantizable units by lifted // functions' names. message FunctionNameMatcherSpec { @@ -84,7 +175,10 @@ message MatcherSpec { // Specifies how to quantize matched quantizable units. message Method { - NoQuantization no_quantization = 1; + oneof method { + NoQuantization no_quantization = 1; + StaticRangePtq static_range_ptq = 2; + } } // A QuantizationSpec is essentially a (matcher spec, quantization method) pair, @@ -158,9 +252,10 @@ message DebuggerConfig { } // Defines various calibration options. +// Next ID: 4 message CalibrationOptions { // Configurations for calibration methods. - // NEXT ID: 7 + // Next ID: 7 enum CalibrationMethod { CALIBRATION_METHOD_UNSPECIFIED = 0; // Use the min, max values of all sample datasets. @@ -185,7 +280,7 @@ message CalibrationOptions { } // Parameters required for calibration. - // NEXT ID: 4 + // Next ID: 4 message CalibrationParameters { // The number of bins when histogram is initialized. It can be increased // because histogram is dynamically expanded by sample inputs. @@ -200,7 +295,7 @@ message CalibrationOptions { } // Determines how to calibrate. - // The default calibration method is MIN_MAX. + // Default value: CALIBRATION_METHOD_MIN_MAX CalibrationMethod calibration_method = 1; // Defines the parameters required for calibration. Parameters such as the @@ -208,21 +303,26 @@ message CalibrationOptions { // MIN_MAX and AVERAGE_MIN_MAX don't require this parameter and methods // starting with HISTOGRAM require this parameter. CalibrationParameters calibration_parameters = 2; + + // Configures representative dataset. Each item corresponds to a + // representative dataset used to calibrate a function. + repeated RepresentativeDatasetConfig representative_datasets = 3; } // Quantization configuration for StableHLO Quantizer. This is the primary // message containing all configurable options. -// Next ID: 7 +// Next ID: 8 message QuantizationConfig { // Config presets provide predefined popular or common quantization specs. // Lightweight users may choose one of the presets for quick experiments. Each - // preset is completely represented by `QuantizationSpecs`. When extra entries - // in `QuantizationSpecs` are provided along with a preset, then the preset - // will be overridden for the quantizable units matched by those additional - // `QuantizationSpec`s. + // preset is completely represented by other fields in `QuantizationConfig`. + // + // When extra entries in `QuantizationSpecs` are provided along with a preset, + // then those entries will take precedence. oneof preset { // Performs best-effort static-range post-training quantization (PTQ). StaticRangePtqPreset static_range_ptq_preset = 1; + WeightOnlyPreset weight_only_preset = 7; } // TF SavedModel specific information for the input model. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD index db4bc1a92483c1..55a41d4ce76072 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD @@ -1,6 +1,6 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir index 6a5b58a7ba7b64..1fe56cde49601d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir @@ -8,10 +8,10 @@ func.func @main(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { } // CHECK: @main(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32> // CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {calibration_method = 0 : i32, {{.*}}} : (tensor<1x4xf32>) -> tensor<1x4xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" -// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {calibration_method = 0 : i32, {{.*}}} : (tensor<1x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> tensor<1x3xf32> // CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32> // CHECK: } // CHECK: } @@ -28,10 +28,10 @@ func.func @serving_default(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { } // CHECK: @serving_default(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32> // CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {calibration_method = 0 : i32, {{.*}}} : (tensor<1x4xf32>) -> tensor<1x4xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" -// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {calibration_method = 0 : i32, {{.*}}} : (tensor<1x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> tensor<1x3xf32> // CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32> // CHECK: } // CHECK: } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir index 55ff087240a5e0..240b10d8438431 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir @@ -1,17 +1,16 @@ // RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics -stablehlo-test-tf-to-stablehlo | FileCheck %s -func.func @fused_batchnorm_no_training() -> (tensor<1x1x2x8xf32>) { - %cst_0 = "tf.Const"() {value = dense<[[[[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2], [0.3, 0.4, 0.3, 0.4, 0.3, 0.4, 0.3, 0.4]]]]> : tensor<1x1x2x8xf32>} : () -> tensor<1x1x2x8xf32> - %cst_1 = "tf.Const"() {value = dense<[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2]> : tensor<8xf32>} : () -> tensor<8xf32> - %cst_2 = "tf.Const"() {value = dense<[0.3, 0.4, 0.3, 0.4, 0.3, 0.4, 0.3, 0.4]> : tensor<8xf32>} : () -> tensor<8xf32> - %0:6 = "tf.FusedBatchNormV3"(%cst_0, %cst_1, %cst_2, %cst_1, %cst_2) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<1x1x2x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<1x1x2x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - func.return %0#0 : tensor<1x1x2x8xf32> -} -// CHECK: func.func @main() -> tensor<1x1x2x8xf32> -// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<{{.*}}> : tensor<1x1x2x8xf32> -// CHECK: return %[[CONST]] : tensor<1x1x2x8xf32> - -// ----- +// TODO(b/330759552): Fix the msan issue and enable this test. +// func.func @fused_batchnorm_no_training() -> tensor<1x1x2x8xf32> { +// %cst_0 = "tf.Const"() {value = dense<[[[[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2], [0.3, 0.4, 0.3, 0.4, 0.3, 0.4, 0.3, 0.4]]]]> : tensor<1x1x2x8xf32>} : () -> tensor<1x1x2x8xf32> +// %cst_1 = "tf.Const"() {value = dense<[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2]> : tensor<8xf32>} : () -> tensor<8xf32> +// %cst_2 = "tf.Const"() {value = dense<[0.3, 0.4, 0.3, 0.4, 0.3, 0.4, 0.3, 0.4]> : tensor<8xf32>} : () -> tensor<8xf32> +// %0:6 = "tf.FusedBatchNormV3"(%cst_0, %cst_1, %cst_2, %cst_1, %cst_2) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<1x1x2x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<1x1x2x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) +// func.return %0#0 : tensor<1x1x2x8xf32> +// } +// COM: CHECK: func.func @main() -> tensor<1x1x2x8xf32> +// COM: CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<{{.*}}> : tensor<1x1x2x8xf32> +// COM: CHECK: return %[[CONST]] : tensor<1x1x2x8xf32> func.func @fused_batchnorm_no_training_arg_input(%arg_0: tensor<1x1x2x8xf32>) -> (tensor<1x1x2x8xf32>) { %cst_0 = "tf.Const"() {value = dense<[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2]> : tensor<8xf32>} : () -> tensor<8xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/defer_activation_transpose.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/defer_activation_transpose.mlir new file mode 100644 index 00000000000000..96b270f8b888f9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/defer_activation_transpose.mlir @@ -0,0 +1,307 @@ +// RUN: stablehlo-quant-opt %s -stablehlo-defer-activation-transpose \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Tests that an `add(transpose(arg0), arg1)` pattern is converted to +// `transpose(add(arg0, transpose(arg1)))`. The transpose in the activation is +// deferred to the output of `stablehlo.add` and an extra transpose op is +// inserted to the RHS to match the shape of the operand. + +// CHECK-LABEL: add_with_activation_transpose +func.func @add_with_activation_transpose(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x4x3x3xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %2 = stablehlo.add %1, %0 : tensor<1x4x3x3xf32> + return %2 : tensor<1x4x3x3xf32> +} +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[CONST_0]], dims = [0, 2, 3, 1] : (tensor<1x4x3x3xf32>) -> tensor<1x3x3x4xf32> + +// Check that the shape of the add is changed to reflect the deferred transpose. +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[ARG_0]], %[[TRANSPOSE_0]] : tensor<1x3x3x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// Tests that an `add(transpose(arg0), broadcast_in_dim(arg1))` pattern is +// converted to `transpose(add(arg0, transpose(broadcast_in_dim(arg1))))`. +// The transpose in the activation is deferred to the output of `stablehlo.add` +// and an extra transpose op is inserted to the RHS to match the shape of the +// operand. + +// CHECK-LABEL: add_with_activation_transpose_broadcasted_rhs +func.func @add_with_activation_transpose_broadcasted_rhs(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> + %1 = stablehlo.broadcast_in_dim %0, dims = [1] : (tensor<4xf32>) -> tensor<1x4x3x3xf32> + %2 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %3 = stablehlo.add %2, %1 : tensor<1x4x3x3xf32> + return %3 : tensor<1x4x3x3xf32> +} +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant +// CHECK-DAG: %[[BROADCAST:.+]] = stablehlo.broadcast_in_dim %[[CONST_0]], dims = [1] : (tensor<4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[BROADCAST]], dims = [0, 2, 3, 1] : (tensor<1x4x3x3xf32>) -> tensor<1x3x3x4xf32> + +// Check that the shape of the add is changed to reflect the deferred transpose. +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[ARG_0]], %[[TRANSPOSE_0]] : tensor<1x3x3x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// [No change] Tests that the activation transpose whose permutation is not +// `[0, 3, 1, 2]` is not deferred. + +// CHECK-LABEL: add_with_activation_transpose_permutation_mismatch +func.func @add_with_activation_transpose_permutation_mismatch( + %arg0: tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x2x4xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 2, 1, 3] : (tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x2x4xf32> + return %2 : tensor<1x3x2x4xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[ADD_0]] + +// ----- + +// [No change] Tests that the activation transpose whose rank is not 4 is not +// deferred. + +// CHECK-LABEL: add_with_activation_transpose_rank_two +func.func @add_with_activation_transpose_rank_two(%arg0: tensor<1x2xf32>) -> tensor<2x1xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<2x1xf32> + %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32> + %2 = stablehlo.add %1, %0 : tensor<2x1xf32> + return %2 : tensor<2x1xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[ADD_0]] + +// ----- + +// [No change] Tests that the right-hand side that is not a constant is not +// deferred. + +// CHECK-LABEL: add_with_activation_transpose_nonconst_rhs +func.func @add_with_activation_transpose_nonconst_rhs(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %1 = stablehlo.add %0, %arg1 : tensor<1x4x3x3xf32> + return %1 : tensor<1x4x3x3xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[ADD_0]] + +// ----- + +// Tests that the transpose of the input of `stablehlo.reduce_window` is +// deferred to the result. The attributes are permutated according to the new +// input shape. + +// CHECK-LABEL: reduce_window_max_activation_transpose +func.func @reduce_window_max_activation_transpose(%arg0: tensor<1x16x16x4xf32>) -> tensor<1x4x8x8xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x16x16x4xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) {window_dimensions = array, window_strides = array} : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x8x8xf32> + return %2 : tensor<1x4x8x8xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<1x16x16x4xf32> +// CHECK-DAG: %[[INIT_VALUE_CONST:.+]] = stablehlo.constant dense<0xFF800000> + +// Check that the body is not modified. +// CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]]) +// CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor, %[[REDUCE_ARG_1:.+]]: tensor): +// CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]] +// CHECK: stablehlo.return %[[MAX]] + +// Check that the attributes window_dimensions & window_strides are also +// permutated to match the new input shape. +// CHECK: {window_dimensions = array, window_strides = array} +// CHECK-SAME: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x8x8x4xf32> + +// Check that a `stablehlo.transpose` is added to the result to match the shape +// of the users. +// CHECK: %[[TRANSPOSE:.+]] = stablehlo.transpose %[[REDUCE_WINDOW]], dims = [0, 3, 1, 2] : (tensor<1x8x8x4xf32>) -> tensor<1x4x8x8xf32> +// CHECK: return %[[TRANSPOSE]] + +// ----- + +// Tests that the transpose of the input of `stablehlo.reduce_window` is +// deferred to the result. The attributes are permutated according to the new +// input shape. This test is similar to the test above with the difference that +// the `stablehlo.reduce_window` has explicit optional attributes: +// `base_dilations` and `window_dilations`. + +// CHECK-LABEL: reduce_window_max_activation_transpose_explicit_optional_attrs +func.func @reduce_window_max_activation_transpose_explicit_optional_attrs( + %arg0: tensor<1x16x16x4xf32>) -> tensor<1x4x15x15xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x16x16x4xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) { + window_dimensions = array, + window_strides = array, + base_dilations = array, + window_dilations = array + } : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x15x15xf32> + return %2 : tensor<1x4x15x15xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<1x16x16x4xf32> +// CHECK-DAG: %[[INIT_VALUE_CONST:.+]] = stablehlo.constant dense<0xFF800000> + +// Check that the body is not modified. +// CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]]) +// CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor, %[[REDUCE_ARG_1:.+]]: tensor): +// CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]] +// CHECK: stablehlo.return %[[MAX]] + +// Check that the attributes window_dimensions & window_strides along with +// optional attributes base_dilations and window_dilations are also permutated +// to match the new input shape. +// CHECK: {base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array} +// CHECK-SAME: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x15x15x4xf32> + +// Check that a `stablehlo.transpose` is added to the result to match the shape +// of the users. +// CHECK: %[[TRANSPOSE:.+]] = stablehlo.transpose %[[REDUCE_WINDOW]], dims = [0, 3, 1, 2] : (tensor<1x15x15x4xf32>) -> tensor<1x4x15x15xf32> +// CHECK: return %[[TRANSPOSE]] + +// ----- + +// [No change] Tests that the transpose of the input of +// `stablehlo.reduce_window` is NOT deferred to the result, when the input +// tensor does not have rank 4. + +// CHECK-LABEL: reduce_window_max_activation_transpose +// CHECK-SAME: (%[[ARG:.+]]: tensor<16x8xf32>) -> tensor<4x8xf32> +func.func @reduce_window_max_activation_transpose_rank2(%arg0: tensor<16x8xf32>) -> tensor<4x8xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<16x8xf32>) -> tensor<8x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) {window_dimensions = array, window_strides = array} : (tensor<8x16xf32>, tensor) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> +} +// CHECK-DAG: stablehlo.constant +// CHECK: stablehlo.transpose %[[ARG]] +// CHECK: stablehlo.reduce_window + +// ----- + +// [No change] Tests that the transpose of the input of +// `stablehlo.reduce_window` is NOT deferred to the result, when it has an +// explicit `padding` attribute. + +// CHECK-LABEL: reduce_window_max_activation_transpose_with_padding +func.func @reduce_window_max_activation_transpose_with_padding(%arg0: tensor<1x16x16x4xf32>) -> tensor<1x4x9x9xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x16x16x4xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) { + window_dimensions = array, + window_strides = array, + padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64> + } : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x9x9xf32> + return %2 : tensor<1x4x9x9xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<1x16x16x4xf32> +// CHECK-DAG: stablehlo.constant +// CHECK: stablehlo.transpose %[[ARG]] +// CHECK: stablehlo.reduce_window + +// ----- + +// [No change] Tests that the transpose of the input of +// `stablehlo.reduce_window` is NOT deferred to the result, when the transpose +// isn't `[0, 3, 1, 2]` (i.e. NCHW->NHWC). + +// CHECK-LABEL: reduce_window_max_activation_transpose_with_padding +func.func @reduce_window_max_activation_transpose_with_padding(%arg0: tensor<16x16x4x1xf32>) -> tensor<1x4x8x8xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<16x16x4x1xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) { + window_dimensions = array, + window_strides = array + } : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x8x8xf32> + return %2 : tensor<1x4x8x8xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<16x16x4x1xf32> +// CHECK-DAG: stablehlo.constant +// CHECK: stablehlo.transpose %[[ARG]] +// CHECK: stablehlo.reduce_window + +// ----- + +// Tests that an `max(transpose(arg0), arg1)` pattern is converted to +// `transpose(max(arg0, transpose(arg1)))`. The transpose in the activation is +// deferred to the output of `stablehlo.max` and an extra transpose op is +// inserted to the RHS to match the shape of the operand. + +// CHECK-LABEL: max_with_activation_transpose +func.func @max_with_activation_transpose(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x4x3x3xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %2 = stablehlo.maximum %1, %0 : tensor<1x4x3x3xf32> + return %2 : tensor<1x4x3x3xf32> +} +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[CONST_0]], dims = [0, 2, 3, 1] : (tensor<1x4x3x3xf32>) -> tensor<1x3x3x4xf32> + +// Check that the shape of the add is changed to reflect the deferred transpose. +// CHECK: %[[MAX_0:.+]] = stablehlo.maximum %[[ARG_0]], %[[TRANSPOSE_0]] : tensor<1x3x3x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// [No change] Tests that the activation transpose of `stablehlo.maximum` whose +// permutation is not `[0, 3, 1, 2]` is not deferred. + +// CHECK-LABEL: max_with_activation_transpose_permutation_mismatch +func.func @max_with_activation_transpose_permutation_mismatch( + %arg0: tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x2x4xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 2, 1, 3] : (tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> + %2 = stablehlo.maximum %1, %0 : tensor<1x3x2x4xf32> + return %2 : tensor<1x3x2x4xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[MAX_0:.+]] = stablehlo.maximum %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[MAX_0]] + +// ----- + +// [No change] Tests that the activation transpose of `stablehlo.maximum` whose +// rank is not 4 is not deferred. + +// CHECK-LABEL: max_with_activation_transpose_rank_two +func.func @max_with_activation_transpose_rank_two(%arg0: tensor<1x2xf32>) -> tensor<2x1xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<2x1xf32> + %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32> + %2 = stablehlo.maximum %1, %0 : tensor<2x1xf32> + return %2 : tensor<2x1xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[MAX_0:.+]] = stablehlo.maximum %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[MAX_0]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize_hybrid.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir similarity index 98% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize_hybrid.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir index 9f68899873f0b0..89ff96efecf471 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize_hybrid.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize-hybrid | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-insert-weight-param | FileCheck %s // Test that q/dq pair is inserted between constant and XlaCallModule op // with quantizable trait and function name containing conv. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir index a0d797cfee4fa2..69bf09104c814d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir @@ -1,7 +1,78 @@ -// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs \ -// RUN: -split-input-file | FileCheck %s +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=disable-all-dot-general" \ +// RUN: -split-input-file | FileCheck %s --check-prefix=DISABLE-ALL-DOT-GENERAL -// CHECK: @main +// Tests that `composite_dot_general_fn_1` and its corresponding XlaCallModuleOp +// contains attributes required for quantization, including the +// `_quantization_method` attribute that contains textpb of `Method`. + +// DISABLE-ALL-DOT-GENERAL: @main +func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + return %1 : tensor<1x1x64xf32> +} + +// DISABLE-ALL-DOT-GENERAL: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00> +// DISABLE-ALL-DOT-GENERAL: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) + +// Check that the `_quantization_method` attribute contains the quantization +// method in textproto format. The dot_general op quantization is explicitly +// disabled by having `_quantization_method = "no_quantization { }"`. +// DISABLE-ALL-DOT-GENERAL-SAME: _entry_function = @composite_dot_general_fn_1 +// DISABLE-ALL-DOT-GENERAL-SAME: _original_entry_function +// DISABLE-ALL-DOT-GENERAL-SAME: _quantization_method = "no_quantization { }" +// DISABLE-ALL-DOT-GENERAL-SAME: _tfl_quant_trait = "fully_quantizable" + +// DISABLE-ALL-DOT-GENERAL: return %[[XLA_CALL_MODULE:.+]] : tensor<1x1x64xf32> +// DISABLE-ALL-DOT-GENERAL: } + +// DISABLE-ALL-DOT-GENERAL-LABEL: private @composite_dot_general_fn_1 +// DISABLE-ALL-DOT-GENERAL-SAME: tf_quant.composite_function +// DISABLE-ALL-DOT-GENERAL: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1 +// DISABLE-ALL-DOT-GENERAL: return %[[DOT_GENERAL:.+]] : tensor<1x1x64xf32> +// DISABLE-ALL-DOT-GENERAL: } + +// ----- + +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=empty" \ +// RUN: -split-input-file | FileCheck %s --check-prefix=EMPTY + +// Tests that `composite_dot_general_fn_1` and its corresponding XlaCallModuleOp +// contains attributes required for quantization. `_quantization_method` is not +// set, as it is implicitly disabled. + +// EMPTY: @main +func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + return %1 : tensor<1x1x64xf32> +} + +// EMPTY: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00> +// EMPTY: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) + +// Check that the `_quantization_method` attribute doesn't contain the +// quantization method, implying "no_quantization". +// EMPTY-SAME: _entry_function = @composite_dot_general_fn_1 +// EMPTY-SAME: _original_entry_function +// EMPTY-NOT: _quantization_method +// EMPTY-SAME: _tfl_quant_trait = "fully_quantizable" + +// EMPTY: return %[[XLA_CALL_MODULE:.+]] : tensor<1x1x64xf32> +// EMPTY: } + +// EMPTY-LABEL: private @composite_dot_general_fn_1 +// EMPTY-SAME: tf_quant.composite_function +// EMPTY: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1 +// EMPTY: return %[[DOT_GENERAL:.+]] : tensor<1x1x64xf32> +// EMPTY: } + +// ----- + +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=static-range-ptq-to-all" \ +// RUN: -split-input-file | FileCheck %s --check-prefix=STATIC-RANGE-PTQ-TO-ALL + +// STATIC-RANGE-PTQ-TO-ALL: @main func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> @@ -11,21 +82,44 @@ func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { // contains attributes required for quantization, including the // `_quantization_method` attribute that contains textpb of `Method`. -// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> -// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// STATIC-RANGE-PTQ-TO-ALL: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00> +// STATIC-RANGE-PTQ-TO-ALL: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) + +// Check that the `_quantization_method` attribute contains the quantization +// method in textproto format, enabling static-range PTQ. +// STATIC-RANGE-PTQ-TO-ALL-SAME: _entry_function = @composite_dot_general_fn_1 +// STATIC-RANGE-PTQ-TO-ALL-SAME: _original_entry_function +// STATIC-RANGE-PTQ-TO-ALL-SAME: _quantization_method = "static_range_ptq { }" +// STATIC-RANGE-PTQ-TO-ALL-SAME: _tfl_quant_trait = "fully_quantizable" + +// STATIC-RANGE-PTQ-TO-ALL: return %[[XLA_CALL_MODULE:.+]] : tensor<1x1x64xf32> +// STATIC-RANGE-PTQ-TO-ALL: } + +// STATIC-RANGE-PTQ-TO-ALL-LABEL: private @composite_dot_general_fn_1 +// STATIC-RANGE-PTQ-TO-ALL-SAME: tf_quant.composite_function +// STATIC-RANGE-PTQ-TO-ALL: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1 +// STATIC-RANGE-PTQ-TO-ALL: return %[[DOT_GENERAL:.+]] : tensor<1x1x64xf32> +// STATIC-RANGE-PTQ-TO-ALL: } + +// ----- + +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=static-range-ptq-to-compute-heavy" \ +// RUN: -split-input-file | FileCheck %s --check-prefix=STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY + +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: @main +func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { + %0 = stablehlo.add %arg0, %arg0 : tensor<1x2xf32> + return %0 : tensor<1x2xf32> +} +// Tests that `composite_add_fn_1` does not quantize when quantizing +// only compute-heavy ops. + +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00> +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %arg0) // Check that the `_quantization_method` attribute contains the quantization -// method in textproto format. -// CHECK-SAME: _entry_function = @composite_dot_general_fn_1 -// CHECK-SAME: _original_entry_function -// CHECK-SAME: _quantization_method = "no_quantization {}" -// CHECK-SAME: _tfl_quant_trait = "fully_quantizable" - -// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> -// CHECK: } - -// CHECK-LABEL: private @composite_dot_general_fn_1 -// CHECK-SAME: tf_quant.composite_function -// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 -// CHECK: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32> -// CHECK: } +// method in textproto format, enabling static-range PTQ. +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _entry_function = @composite_add_fn_1 +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _original_entry_function +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY-NOT: _quantization_method +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _tfl_quant_trait = "fully_quantizable" diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/nchw_convolution_to_nhwc.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/nchw_convolution_to_nhwc.mlir index 6cdf9fdbf46b91..bdfce8cad3f5a8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/nchw_convolution_to_nhwc.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/nchw_convolution_to_nhwc.mlir @@ -75,8 +75,8 @@ func.func @conv_output_dim_numbers_mismatch(%arg0: tensor<1x8x4x4xf32>) -> tenso // Tests that a quantized convolution does not match. No conversion occurs. // CHECK-LABEL: quantized_convolution -func.func @quantized_convolution(%arg0: tensor<1x4x3x3x!quant.uniform>, %arg1: tensor<2x4x3x3x!quant.uniform>) -> tensor<1x2x3x3x!quant.uniform> { - %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x3x3x!quant.uniform>, tensor<2x4x3x3x!quant.uniform>) -> tensor<1x2x3x3x!quant.uniform> +func.func @quantized_convolution(%arg0: tensor<1x4x3x3x!quant.uniform>, %arg1: tensor<2x4x3x3x!quant.uniform>) -> tensor<1x2x3x3x!quant.uniform> { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x3x3x!quant.uniform>, tensor<2x4x3x3x!quant.uniform>) -> tensor<1x2x3x3x!quant.uniform> return %0 : tensor<1x2x3x3x!quant.uniform> } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir index 9b3c6f0f0ae04f..1ff62b1170a6f5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir @@ -17,8 +17,11 @@ module { // CHECK: "tf.XlaCallModule"(%[[dq_act]], %[[dq_weight]] %1 = "tf.XlaCallModule"(%0, %cst_0, %cst) { Sout = [#tf_type.shape<1x2x2x2>], config = "", - _entry_function = @composite_conv2d_with_bias_and_relu6_fn_10, module = "composite_conv2d_with_bias_and_relu6_fn_10", + _entry_function = @composite_conv2d_with_bias_and_relu6_fn_10, + // Represents a per-channel quantization for the operand index 1 with + // quantization dimension of 3 + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", platforms = [], version = 4 : i64 } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> @@ -90,3 +93,38 @@ module { return %0 : tensor<2x2xf32> } } + +// ----- + +// Tests that the `PrepareQuantizePass` prepares for per-tensor quantization for +// the weight of convolution. This is based on the `_quantization_method` that +// does not have a `input_quantized_types` with a specified `dimension_specs`. + +// CHECK-LABEL: conv_per_tensor_quantized_method +func.func private @conv_per_tensor_quantized_method(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> { + %cst = "tf.Const"() {device = "", value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[[[[-6.30731344, 5.4962182], [1.80364347, -7.64542675], [-2.11145878, -7.08605719]], [[-9.54062747, -6.14013147], [6.12640238, -4.18223286], [5.05738974, 8.99269962]], [[3.3535192, 0.84816426], [-6.64676809, -7.95477629], [5.81315517, 9.21566581]]], [[[1.38622558, 4.63866329], [4.54742622, -1.43770897], [-3.96835279, 2.99996852]], [[0.989735424, -4.83384752], [-7.27702999, 1.17216611], [1.33735656, 0.728900194]], [[5.1286211, 8.98645591], [1.55008793, -3.85491467], [3.7003777, 9.26594448]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[1.27501142, 4.824783]> : tensor<2xf32>} : (tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], config = "", + module = "composite_conv_fn_1", + _entry_function = @composite_conv_fn_1, + _quantization_method = "static_range_ptq {}", + platforms = [], version = 4 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %2 : tensor<1x2x2x2xf32> +} +// CHECK-SAME: %[[ARG_0:.+]]: tensor<1x3x2x3xf32> + +// Test that the weight is prepared for per-tensor quantization, based on the +// `_quantization_method` attribute without a `dimension_specs` field in +// `QuantizedType`. +// CHECK-DAG: %[[WEIGHT_CONST:.+]] = stablehlo.constant {{.*}} tensor<2x3x3x2xf32> +// CHECK: %[[Q_WEIGHT_PER_TENSOR:.*]] = "quantfork.qcast"(%[[WEIGHT_CONST]]) {{.*}} (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> +// CHECK: %[[DQ_WEIGHT:.*]] = "quantfork.dcast"(%[[Q_WEIGHT_PER_TENSOR]]) + +// CHECK: %[[Q_ACTIVATION:.*]] = "quantfork.qcast"(%[[ARG_0]]) +// CHECK-SAME: -> tensor<1x3x2x3x!quant.uniform> +// CHECK: %[[DQ_ACTIVATION:.*]] = "quantfork.dcast"(%[[Q_ACTIVATION]]) +// CHECK: "tf.XlaCallModule"(%[[DQ_ACTIVATION]], %[[DQ_WEIGHT]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_hybrid.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_hybrid.mlir deleted file mode 100644 index f9a6aaea3a500f..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_hybrid.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize=enable-weight-only=true | FileCheck %s - -// Test that hybrid quantized op is produced when q/dq pair only exists for weight. - -module attributes {tf_saved_model.semantics} { - func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> - %0 = "quantfork.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> - %1 = "quantfork.dcast"(%0) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32> - %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %2 : tensor<1x3xf32> - } - - func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } -} - -// CHECK-LABEL: quantize_dot_general_fn -// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> -// CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> -// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> -// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[Q]]) : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> -// CHECK: return %[[CALL]] - -// CHECK: quantized_dot_general_fn -// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> -// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] -// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> -// CHECK: return %[[DOT]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir new file mode 100644 index 00000000000000..6db474de676ccc --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir @@ -0,0 +1,65 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize=enable-weight-only=true | FileCheck %s + +// Test that hybrid quantized dot_general is produced when q/dq pair only exists +// for weight. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> + %0 = "quantfork.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> + %1 = "quantfork.dcast"(%0) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// CHECK-LABEL: quantize_dot_general_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> +// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[Q]]) : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_dot_general_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Test that hybrid quantized convolution is produced when q/dq pair only exists +// for weight. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %0 = "quantfork.qcast"(%cst) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> + %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: quantize_conv_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> +// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[Q]]) : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir index 13570eb583110e..f9fa9ce5f60b87 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir @@ -187,13 +187,31 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that basic convolution is properly quantized. +// Tests that basic convolution is properly quantized. It is per-channel +// quantized unless `enable-per-channel-quantized-weight=false`, according to +// `_quantization_method` with an `input_quantized_types` and explicit +// `dimension_specs`. module attributes {tf_saved_model.semantics} { func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64, _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64, + _entry_function = @composite_conv_fn, + _original_entry_function = "composite_conv_fn", + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _stablehlo_module_attrs = {}, + _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> return %2 : tensor<1x3x4x2xf32> } @@ -235,6 +253,58 @@ module attributes {tf_saved_model.semantics} { // ----- +// Tests that basic convolution is properly quantized. In this example, the +// convolution is always per-tensor quantized (even if +// enable-per-channel-quantized-weights=true), according to +// `_quantization_method`. + +// CHECK-LABEL: quantize_conv_fn_per_tensor +func.func @quantize_conv_fn_per_tensor(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64, + _entry_function = @composite_conv_fn, + _original_entry_function = "composite_conv_fn", + _quantization_method = "static_range_ptq {}", + _stablehlo_module_attrs = {}, + _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> +} +// Check that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. + +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> +} +// Checks that the entry function is quantized for convolution. Quantized +// convolution outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. + +// CHECK: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// ----- + // Tests that fused pattern for convolution + bias is properly quantized. // Checks that fused functions with 1D bias is properly quantized. @@ -246,7 +316,22 @@ module attributes {tf_saved_model.semantics} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_1d_fn, _original_entry_function = "composite_conv_with_bias_1d_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_1d_fn, + _original_entry_function = "composite_conv_with_bias_1d_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> return %2 : tensor<1x3x4x2xf32> } @@ -298,7 +383,22 @@ module attributes {tf_saved_model.semantics} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_fn, _original_entry_function = "composite_conv_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_fn, + _original_entry_function = "composite_conv_with_bias_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> return %2 : tensor<1x3x4x2xf32> } @@ -349,7 +449,22 @@ module attributes {tf_saved_model.semantics} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_dynamic_fn, _original_entry_function = "composite_conv_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_dynamic_fn, + _original_entry_function = "composite_conv_with_bias_dynamic_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor return %2 : tensor } @@ -426,7 +541,22 @@ module attributes {tf_saved_model.semantics} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_and_relu_dynamic_fn, _original_entry_function = "composite_conv_with_bias_and_relu_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_and_relu_dynamic_fn, + _original_entry_function = "composite_conv_with_bias_and_relu_dynamic_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor %2 = "quantfork.stats"(%1) {layerStats = dense<[0.00000000e-6, 8.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor return %2 : tensor } @@ -506,7 +636,22 @@ module attributes {tf_saved_model.semantics} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_and_relu6_dynamic_fn, _original_entry_function = "composite_conv_with_bias_and_relu6_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_and_relu6_dynamic_fn, + _original_entry_function = "composite_conv_with_bias_and_relu6_dynamic_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor return %2 : tensor } @@ -598,7 +743,7 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that basic gather is properly quantized. +// Tests that basic `stablehlo.gather` is properly quantized. module attributes {tf_saved_model.semantics} { // CHECK: func.func private @quantize_gather_fn(%[[ARG:.+]]: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} @@ -631,6 +776,5 @@ module attributes {tf_saved_model.semantics} { return %0 : tensor<2x3x2x2xf32> } // CHECK: %[[GATHER:.+]] = "stablehlo.gather"(%[[ARG_0]], %[[ARG_1]]) {{.*}} : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> -// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[GATHER]] : tensor<2x3x2x2x!quant.uniform> -// CHECK: return %[[UNIFORM_QUANTIZE]] : tensor<2x3x2x2x!quant.uniform> +// CHECK: return %[[GATHER]] : tensor<2x3x2x2x!quant.uniform> } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir new file mode 100644 index 00000000000000..72851d92b64b75 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir @@ -0,0 +1,46 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -stablehlo-quantize-composite-functions=enable-full-int-quantization=true | FileCheck --check-prefix=CHECK-FULL-INT %s + +// Tests that a basic `stablehlo.add` and a fused `stablehlo.dot_general` +// are properly quantized. + +module attributes {tf_saved_model.semantics} { +// CHECK-FULL-INT: func.func private @quantize_add_fn(%[[ARG:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} + func.func private @quantize_add_fn(%arg: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst_0 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> + %cst_1 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _original_entry_function = "composite_add_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %3 = "quantfork.stats"(%2) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %4 = "tf.XlaCallModule"(%3, %cst_1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantfork.stats"(%4) {layerStats = dense<[5.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %5 : tensor<1x3xf32> + } +// CHECK-FULL-INT: %[[CONST:.+]] = stablehlo.constant() {value = dense<127> : tensor<1x2xi8>} : () -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>> +// CHECK-FULL-INT: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: %[[CALL:.+]] = call @quantized_add_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK-FULL-INT: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[UNIFORM_DEQUANTIZE]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK-FULL-INT: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK-FULL-INT: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK-FULL-INT: func.func private @quantized_add_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_add_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32> + return %0 : tensor<1x2xf32> + } +// CHECK-FULL-INT: %[[ADD:.+]] = stablehlo.add %arg0, %arg1 : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: return %[[ADD]] : tensor<1x2x!quant.uniform> + +// CHECK-FULL-INT: func.func private @quantized_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// CHECK-FULL-INT: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1,{{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK-FULL-INT: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-FULL-INT: return %[[UNIFORM_QUANTIZE]] : tensor<1x3x!quant.uniform> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_hybrid.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_hybrid.mlir deleted file mode 100644 index aa42045251778c..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_hybrid.mlir +++ /dev/null @@ -1,30 +0,0 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ -// RUN: -stablehlo-quantize-composite-functions=enable-weight-only=true | FileCheck --check-prefix=CHECK %s - -// Test that hybrid quantized dot_general op is produced when hybrid-quantize -// is set to true. - -module attributes {tf_saved_model.semantics} { - func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %1 : tensor<1x3xf32> - } - - func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } -} - -// CHECK-LABEL: quantize_dot_general_fn -// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> -// CHECK: %[[CST:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> -// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> -// CHECK: return %[[CALL]] - -// CHECK: quantized_dot_general_fn -// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> -// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] -// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> -// CHECK: return %[[DOT]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir new file mode 100644 index 00000000000000..dce15fe07760e2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir @@ -0,0 +1,60 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -stablehlo-quantize-composite-functions=enable-weight-only=true | FileCheck --check-prefix=CHECK %s + +// Test that weight-only quantized dot_general op is produced when +// enable-weight-only is set to true. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// CHECK-LABEL: quantize_dot_general_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_dot_general_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Test that hybrid quantized convolution op is produced when enable-weight-only +// is set to true. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %1 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: quantize_conv_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir new file mode 100644 index 00000000000000..831131a4c64555 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/pipelines/process_nchw_tensor.mlir @@ -0,0 +1,171 @@ +// RUN: stablehlo-quant-opt %s -stablehlo-process-nchw-tensor \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Tests that a `convolution(%activation, %weight)` with the activation tensor +// NCHW format is converted to NHWC convolution. Transpose ops are inserted to +// the activation and output to match the function signature. The weight +// constant is transposed. + +// CHECK-LABEL: nchw_conv +// CHECK-SAME: %[[ARG:.+]]: tensor<1x8x4x4xf32> +func.func @nchw_conv(%arg0: tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> { + %0 = stablehlo.constant() {value = dense<7.000000e+00> : tensor<8x8x3x3xf32>} : () -> tensor<8x8x3x3xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x8x4x4xf32>, tensor<8x8x3x3xf32>) -> tensor<1x8x4x4xf32> + return %2 : tensor<1x8x4x4xf32> +} +// CHECK-DAG: %[[CONST:.+]] = stablehlo.constant {{.*}} : tensor<3x3x8x8xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[ARG]], dims = [0, 2, 3, 1] : (tensor<1x8x4x4xf32>) -> tensor<1x4x4x8xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[TRANSPOSE_0]], %[[CONST]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = {{\[\[}}1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x4x8xf32>, tensor<3x3x8x8xf32>) -> tensor<1x4x4x8xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[CONV]], dims = [0, 3, 1, 2] : (tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// Tests that a `add(convolution(%activation, %weight), %bias)` with the +// activation tensor of NCHW format is converted to NHWC convolution + add +// operation. Transpose ops are inserted to activations and outputs to match the +// function signature. Constants are also transposed accordingly. + +// CHECK-LABEL: nchw_conv_with_bias_add +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x5x5xf32> +func.func @nchw_conv_with_bias_add(%arg0: tensor<1x2x5x5xf32>) -> tensor<1x4x5x5xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<4x2x3x3xf32> + %1 = stablehlo.constant dense<3.000000e+00> : tensor<1x4x5x5xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x5x5xf32>, tensor<4x2x3x3xf32>) -> tensor<1x4x5x5xf32> + %3 = stablehlo.add %2, %1 : tensor<1x4x5x5xf32> + return %3 : tensor<1x4x5x5xf32> +} +// CHECK-DAG: %[[WEIGHT_CONST:.+]] = stablehlo.constant {{.*}} : tensor<3x3x2x4xf32> +// CHECK-DAG: %[[BIAS_CONST:.+]] = stablehlo.constant {{.*}} : tensor<1x5x5x4xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[ARG]], dims = [0, 2, 3, 1] : (tensor<1x2x5x5xf32>) -> tensor<1x5x5x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[TRANSPOSE_0]], %[[WEIGHT_CONST]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = {{\[\[}}1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2xf32>, tensor<3x3x2x4xf32>) -> tensor<1x5x5x4xf32> +// CHECK: %[[ADD:.+]] = stablehlo.add %[[CONV]], %[[BIAS_CONST]] : tensor<1x5x5x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[ADD]], dims = [0, 3, 1, 2] : (tensor<1x5x5x4xf32>) -> tensor<1x4x5x5xf32> +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// Tests that a `add(convolution(%activation, %weight), %bias)` pattern with the +// activation tensor of NCHW format and non-constant bias is converted to NHWC +// convolution, but without the deferred transpose for `stablehlo.add`. +// Transpose ops are inserted to the activation and output of +// `stablehlo.convolution`. The weight constants is transposed. + +// CHECK-LABEL: nchw_conv_with_nonconst_bias_add +// CHECK-SAME: %[[ARG_0:.+]]: tensor<1x2x5x5xf32> +// CHECK-SAME: %[[ARG_1:.+]]: tensor<1x4x5x5xf32> +func.func @nchw_conv_with_nonconst_bias_add(%arg0: tensor<1x2x5x5xf32>, %arg1: tensor<1x4x5x5xf32>) -> tensor<1x4x5x5xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<4x2x3x3xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x5x5xf32>, tensor<4x2x3x3xf32>) -> tensor<1x4x5x5xf32> + %2 = stablehlo.add %1, %arg1 : tensor<1x4x5x5xf32> + return %2 : tensor<1x4x5x5xf32> +} +// CHECK-DAG: %[[WEIGHT_CONST:.+]] = stablehlo.constant {{.*}} : tensor<3x3x2x4xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[ARG_0]], dims = [0, 2, 3, 1] : (tensor<1x2x5x5xf32>) -> tensor<1x5x5x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[TRANSPOSE_0]], %[[WEIGHT_CONST]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = {{\[\[}}1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2xf32>, tensor<3x3x2x4xf32>) -> tensor<1x5x5x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[CONV]], dims = [0, 3, 1, 2] : (tensor<1x5x5x4xf32>) -> tensor<1x4x5x5xf32> +// CHECK: %[[ADD:.+]] = stablehlo.add %[[TRANSPOSE_1]], %[[ARG_1]] : tensor<1x4x5x5xf32> +// CHECK: return %[[ADD]] + +// ----- + +// Tests that a `reduce_window{max}(add(convolution(%activation, %weight), %bias), %init_value)` +// with the activation tensor of NCHW format is converted to NHWC convolution + +// add + reduce_window (with max) operation. Transpose ops are inserted to +// activation and the final result to match the function signature. Constants +// are also transposed accordingly. + +// CHECK-LABEL: nchw_conv_with_bias_add_max_pool +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x5x5xf32> +func.func @nchw_conv_with_bias_add_max_pool(%arg0: tensor<1x2x5x5xf32>) -> tensor<1x4x2x2xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<4x2x3x3xf32> + %1 = stablehlo.constant dense<3.000000e+00> : tensor<1x4x5x5xf32> + %5 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x5x5xf32>, tensor<4x2x3x3xf32>) -> tensor<1x4x5x5xf32> + %3 = stablehlo.add %2, %1 : tensor<1x4x5x5xf32> + %4 = "stablehlo.reduce_window"(%3, %5) ({ // max pool + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %6 : tensor + }) { + window_dimensions = array, + window_strides = array + } : (tensor<1x4x5x5xf32>, tensor) -> tensor<1x4x2x2xf32> + return %4 : tensor<1x4x2x2xf32> +} +// CHECK-DAG: %[[WEIGHT_CONST:.+]] = stablehlo.constant {{.*}} : tensor<3x3x2x4xf32> +// CHECK-DAG: %[[BIAS_CONST:.+]] = stablehlo.constant {{.*}} : tensor<1x5x5x4xf32> +// CHECK-DAG: %[[INIT_VALUE_CONST:.+]] = stablehlo.constant dense<0xFF800000> : tensor +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[ARG]], dims = [0, 2, 3, 1] : (tensor<1x2x5x5xf32>) -> tensor<1x5x5x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[TRANSPOSE_0]], %[[WEIGHT_CONST]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = {{\[\[}}1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2xf32>, tensor<3x3x2x4xf32>) -> tensor<1x5x5x4xf32> +// CHECK: %[[ADD:.+]] = stablehlo.add %[[CONV]], %[[BIAS_CONST]] : tensor<1x5x5x4xf32> +// CHECK: %[[REDUCE_WINDOW_MAX:.+]] = "stablehlo.reduce_window"(%[[ADD]], %[[INIT_VALUE_CONST:.+]]) +// CHECK: stablehlo.maximum +// CHECK: {window_dimensions = array, window_strides = array} : (tensor<1x5x5x4xf32>, tensor) -> tensor<1x2x2x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[REDUCE_WINDOW_MAX]], dims = [0, 3, 1, 2] : (tensor<1x2x2x4xf32>) -> tensor<1x4x2x2xf32> +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// Tests that a `maximum(add(convolution(%activation, %weight), %bias), %zero)` +// with the activation tensor of NCHW format is converted to NHWC convolution + +// add + maximum operation. Transpose ops are inserted to the activation and the +// final output to match the function signature. Constants are also transpose- +// folded accordingly. + +// CHECK-LABEL: nchw_conv_with_bias_add_relu +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x5x5xf32> +func.func @nchw_conv_with_bias_add_relu(%arg0: tensor<1x2x5x5xf32>) -> tensor<1x4x5x5xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<4x2x3x3xf32> + %5 = stablehlo.constant dense<0.000000e+00> : tensor<1x4x5x5xf32> + %1 = stablehlo.constant dense<3.000000e+00> : tensor<1x4x5x5xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x5x5xf32>, tensor<4x2x3x3xf32>) -> tensor<1x4x5x5xf32> + %3 = stablehlo.add %2, %1 : tensor<1x4x5x5xf32> + %4 = stablehlo.maximum %3, %5 : tensor<1x4x5x5xf32> + return %4 : tensor<1x4x5x5xf32> +} +// CHECK-DAG: %[[WEIGHT_CONST:.+]] = stablehlo.constant {{.*}} : tensor<3x3x2x4xf32> +// CHECK-DAG: %[[ZERO_CONST:.+]] = stablehlo.constant {{.*}} : tensor<1x5x5x4xf32> +// CHECK-DAG: %[[BIAS_CONST:.+]] = stablehlo.constant {{.*}} : tensor<1x5x5x4xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[ARG]], dims = [0, 2, 3, 1] : (tensor<1x2x5x5xf32>) -> tensor<1x5x5x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[TRANSPOSE_0]], %[[WEIGHT_CONST]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = {{\[\[}}1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2xf32>, tensor<3x3x2x4xf32>) -> tensor<1x5x5x4xf32> +// CHECK: %[[ADD:.+]] = stablehlo.add %[[CONV]], %[[BIAS_CONST]] : tensor<1x5x5x4xf32> +// CHECK: %[[MAX:.+]] = stablehlo.maximum %[[ADD]], %[[ZERO_CONST]] : tensor<1x5x5x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[MAX]], dims = [0, 3, 1, 2] : (tensor<1x5x5x4xf32>) -> tensor<1x4x5x5xf32> +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// Tests that a `maximum(add(convolution(%activation, %weight), broadcast(%bias) +// ), %zero)` with the activation tensor of NCHW format is converted to NHWC +// convolution + add + maximum operation. Transpose ops are inserted to the +// first activation, final output, and the bias constant (after the broadcast), +// to match the function signature. Constants are also transpose-folded +// accordingly. +// +// Note that the `transpose` after the `broadcast_in_dim` is not folded by the +// `FoldConstantTransposePass`. + +// CHECK-LABEL: nchw_conv_with_broadcasted_bias_add_relu +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x5x5xf32> +func.func @nchw_conv_with_broadcasted_bias_add_relu(%arg0: tensor<1x2x5x5xf32>) -> tensor<1x4x5x5xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<4x2x3x3xf32> // weight + %1 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32> // bias + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x4x5x5xf32> // relu + %3 = stablehlo.broadcast_in_dim %1, dims = [1] : (tensor<4xf32>) -> tensor<1x4x5x5xf32> + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x5x5xf32>, tensor<4x2x3x3xf32>) -> tensor<1x4x5x5xf32> + %5 = stablehlo.add %4, %3 : tensor<1x4x5x5xf32> + %6 = stablehlo.maximum %5, %2 : tensor<1x4x5x5xf32> + return %6 : tensor<1x4x5x5xf32> +} +// CHECK-DAG: %[[WEIGHT_CONST:.+]] = stablehlo.constant {{.*}} : tensor<3x3x2x4xf32> +// CHECK-DAG: %[[ZERO_CONST:.+]] = stablehlo.constant {{.*}} : tensor<1x5x5x4xf32> +// CHECK-DAG: %[[BIAS_CONST:.+]] = stablehlo.constant {{.*}} : tensor<4xf32> +// CHECK-DAG: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %[[BIAS_CONST]], dims = [1] : (tensor<4xf32>) -> tensor<1x4x5x5xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[ARG]], dims = [0, 2, 3, 1] : (tensor<1x2x5x5xf32>) -> tensor<1x5x5x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[TRANSPOSE_0]], %[[WEIGHT_CONST]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = {{\[\[}}1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2xf32>, tensor<3x3x2x4xf32>) -> tensor<1x5x5x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[BROADCAST_IN_DIM]], dims = [0, 2, 3, 1] : (tensor<1x4x5x5xf32>) -> tensor<1x5x5x4xf32> +// CHECK: %[[ADD:.+]] = stablehlo.add %[[CONV]], %[[TRANSPOSE_1]] : tensor<1x5x5x4xf32> +// CHECK: %[[MAX:.+]] = stablehlo.maximum %[[ADD]], %[[ZERO_CONST]] : tensor<1x5x5x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[MAX]], dims = [0, 3, 1, 2] : (tensor<1x5x5x4xf32>) -> tensor<1x4x5x5xf32> +// CHECK: return %[[TRANSPOSE_1]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc index 69d9a725a37ebe..9b587e4273965f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc @@ -29,6 +29,7 @@ limitations under the License. #include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h" @@ -53,6 +54,9 @@ int main(int argc, char** argv) { // These passes are only used for testing purposes. mlir::quant::stablehlo::testing::registerTestPasses(); + // Register StableHLO Quantizer pass pipelines. + mlir::quant::stablehlo::RegisterPassPipelines(); + mlir::DialectRegistry registry; registry.insert +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo namespace mlir::quant::stablehlo { @@ -30,19 +32,25 @@ class StablehloTypeUtilsTest : public Test { protected: StablehloTypeUtilsTest() { ctx_.loadDialect(); + mlir::arith::ArithDialect, mlir::func::FuncDialect>(); } MLIRContext ctx_; OpBuilder builder_{&ctx_}; }; -TEST_F(StablehloTypeUtilsTest, ValidStablehloOpSucceeds) { - mlir::stablehlo::ConstantOp constant_op = +TEST_F(StablehloTypeUtilsTest, IsStablehloOpSucceedsWithStablehloOp) { + const OwningOpRef constant_op = builder_.create( builder_.getUnknownLoc(), builder_.getI32IntegerAttr(0)); - EXPECT_TRUE(IsStablehloOp(constant_op)); - constant_op->erase(); + EXPECT_TRUE(IsStablehloOp(*constant_op)); +} + +TEST_F(StablehloTypeUtilsTest, IsStablehloOpFailsWithArithOp) { + const OwningOpRef constant_op = + builder_.create(builder_.getUnknownLoc(), + builder_.getI32IntegerAttr(0)); + EXPECT_FALSE(IsStablehloOp(*constant_op)); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 6ef72d68c8ea83..be0792ab76aff3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -1,7 +1,7 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:strict.default.bzl", "py_strict_binary") # Placeholder: load py_proto_library -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") @@ -301,6 +301,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "@llvm-project//llvm:Support", + "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", @@ -406,6 +407,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/common:func", "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", @@ -442,9 +444,11 @@ cc_library( "//tensorflow/lite/kernels/internal:quantization_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_googlesource_code_re2//:re2", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index 62a6f27c8ad5f1..23ce2105634854 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -183,6 +183,7 @@ tf_cc_test( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD index 1734fa03aefe3e..de23418e1af031 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD @@ -51,7 +51,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@llvm-project//llvm:Support", + "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", @@ -85,8 +85,8 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/common/quantization_lib", - "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc index 8a9dc4eb3d4989..52ca3722a12bd5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc @@ -20,10 +20,12 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.h b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.h index 5d25779826e81c..bc6031eea7d85b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.h @@ -28,6 +28,7 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.cc index c86968b319c6dd..afeb8905855837 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.cc @@ -16,7 +16,9 @@ limitations under the License. #include -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir::quant { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc index 00ee53b84647eb..239fe32946ab87 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc @@ -182,7 +182,7 @@ class AddDumpTensorOp : public OpRewritePattern { rewriter.getNamedAttr("file_name", rewriter.getStringAttr(file_name)), // The op is disabled by default. Otherwise, values will be saved // during calibration. - rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(false)), + rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(enabled)), rewriter.getNamedAttr("func_name", rewriter.getStringAttr(func_name)), rewriter.getNamedAttr("node_name", rewriter.getStringAttr(node_name)), }; @@ -246,7 +246,7 @@ class AddDumpTensorOp : public OpRewritePattern { // Attach DumpTensorOp to its output layer. SmallVector dump_attributes = CreateDumpAttributes(rewriter, folder_name, file_name, - /*enabled=*/false, func_name, node_name); + /*enabled=*/true, func_name, node_name); rewriter.create(op->getLoc(), TypeRange{}, result, dump_attributes); @@ -261,7 +261,7 @@ class AddDumpTensorOp : public OpRewritePattern { // Attach second DumpTensorOp to its output unquantized layer. SmallVector dump_attributes = CreateDumpAttributes( rewriter, folder_name, /*file_name=*/"unquantized_tensor_data.pb", - /*enabled=*/false, func_name, node_name); + /*enabled=*/true, func_name, node_name); rewriter.create(op.getLoc(), TypeRange{}, new_op->getResult(0), dump_attributes); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/duplicate_shape_determining_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/duplicate_shape_determining_constants.cc index 5237102335e5df..8590a00775cdf0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/duplicate_shape_determining_constants.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/duplicate_shape_determining_constants.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include "absl/algorithm/container.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -26,9 +26,11 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // Required to use LLVM_DEBUG macro. #define DEBUG_TYPE "quant-duplicate-shape-determining-constants" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index e518826d7e6d12..56b9d7393aacfd 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -32,6 +33,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" @@ -45,6 +47,7 @@ namespace quant { namespace { using ::stablehlo::quantization::CalibrationOptions; +using ::stablehlo::quantization::Method; constexpr StringRef kQuantTraitAttrName = "_tfl_quant_trait"; @@ -199,7 +202,7 @@ class AddCustomAggregationOp : public RewritePattern { // The CustomAggregatorOp is only added after quantizable values. SmallVector quantizable_values; - if (isCallToLiftedFunction(op)) { + if (IsCallToQuantizableLiftedFunction(op)) { // Quantize inputs of quantizable composite functions. for (Value input : op->getOperands()) { Type element_type = getElementTypeOrSelf(input.getType()); @@ -226,7 +229,7 @@ class AddCustomAggregationOp : public RewritePattern { // Quantize output of fully quantizable composite functions. for (Value input : op->getOperands()) { auto defining_op = input.getDefiningOp(); - if (!isCallToLiftedFunction(defining_op)) { + if (!IsCallToQuantizableLiftedFunction(defining_op)) { continue; } @@ -282,9 +285,13 @@ class AddCustomAggregationOp : public RewritePattern { CalibrationOptions calib_opts_; // Whether the op is a call op to lifted composite function. - bool isCallToLiftedFunction(Operation *op) const { + bool IsCallToQuantizableLiftedFunction(Operation *op) const { if (!op) return false; - if (isa(op)) return true; + if (auto xla_call_module_op = dyn_cast_or_null(op); + xla_call_module_op != nullptr) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.ok() && method->has_static_range_ptq()) return true; + } TF::PartitionedCallOp call_op = dyn_cast_or_null(op); return call_op && call_op->hasAttrOfType(kQuantTraitAttrName) && diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc index 20ffa5aa9b793c..47ab3b82fc2f24 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc @@ -19,13 +19,17 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.h" @@ -38,6 +42,7 @@ namespace quant { namespace { using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; +using ::tensorflow::quantization::OpSet; class InsertQuantizedFunctionsPass : public PassWrapper #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" @@ -39,6 +39,7 @@ namespace { using QuantMethod = ::tensorflow::quantization::QuantizationMethod::PresetMethod; +using ::tensorflow::quantization::OpSet; class LiftQuantizableSpotsAsFunctionsDRQPass : public PassWrapper -#include #include +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" // IWYU pragma: keep - required to use `IsSplatValueEqual`. #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h index 97a383631e70db..5ea5a058cc94d3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" @@ -49,14 +49,14 @@ CreateLiftQuantizableSpotsAsFunctionsPass( // Apply graph optimizations such as fusing and constant folding to prepare // lifting. std::unique_ptr> CreatePrepareLiftingPass( - OpSet target_opset); + tensorflow::quantization::OpSet target_opset); // Lifts the dynamic range quantizable spots as composite functions. std::unique_ptr> CreateLiftQuantizableSpotsAsFunctionsDRQPass( tensorflow::quantization::QuantizationMethod::PresetMethod quantization_method, - OpSet op_set, int min_num_elements_for_weights); + tensorflow::quantization::OpSet op_set, int min_num_elements_for_weights); // Replaces tf.CustomAggregator ops with quant.Stats ops for finalizing the // calibration procedure. @@ -71,7 +71,7 @@ CreateIssueIDsOfCustomAggregationOpsPass(); std::unique_ptr> CreateInsertQuantizedFunctionsPass( tensorflow::quantization::QuantizationMethod::PresetMethod quantization_method, - OpSet target_opset); + tensorflow::quantization::OpSet target_opset); // Inserts custom aggregation operators for the calibration procedure. std::unique_ptr> @@ -86,8 +86,9 @@ CreateInsertCustomAggregationOpsPass( std::unique_ptr> CreateQuantizeCompositeFunctionsPass( tensorflow::quantization::QuantizationMethod::PresetMethod quantization_method, - OpSet target_opset, bool enable_per_channel_quantization, - int min_num_elements_for_weight, bool enable_legacy_weight_only = false, + tensorflow::quantization::OpSet target_opset, + bool enable_per_channel_quantization, int min_num_elements_for_weights, + bool enable_legacy_weight_only = false, std::optional mlir_dump_file_prefix = std::nullopt); @@ -100,7 +101,8 @@ std::unique_ptr> CreateQuantizePass(); // Overloading of CreateQuantizePass which takes QuantizationSpecs. std::unique_ptr> CreateQuantizePass( - QuantizationSpecs quant_specs, OpSet target_opset); + QuantizationSpecs quant_specs, + tensorflow::quantization::OpSet target_opset); // Creates an instance of the PrepareQuantize pass, which will perform similar // transformations as TFL::PrepareQuantizePass. @@ -112,12 +114,13 @@ std::unique_ptr> CreatePrepareQuantizePass( // Creates an instance of the PrepareQuantizeDRQ pass, which will // perform similar transformations as TFL::PrepareQuantizeDynamicRangePass. std::unique_ptr> CreatePrepareQuantizeDRQPass( - const QuantizationSpecs& quant_specs, OpSet op_set); + const QuantizationSpecs& quant_specs, + tensorflow::quantization::OpSet op_set); // Creates an instance of the PreprocessOp pass, which will perform op // preprocessing to allow multi-axis quantization, prior to quantization. std::unique_ptr> CreatePreprocessOpPass( - OpSet op_set, + tensorflow::quantization::OpSet op_set, tensorflow::quantization::QuantizationMethod::PresetMethod quantization_method, bool enable_per_channel_quantization); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc index ebdd374288a065..38075bb67b7010 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc @@ -49,6 +49,8 @@ namespace mlir { namespace quant { namespace { +using ::tensorflow::quantization::OpSet; + class PrepareLiftingPass : public PassWrapper> { public: diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc index 3a42967e6ada1b..fe38ed8dc0f634 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc @@ -34,7 +34,6 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc index af02c3694fc16d..71587390580406 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc @@ -45,7 +45,7 @@ namespace { using QuantizationUnit = std::pair; using QuantizationUnits = llvm::SetVector; -using ::mlir::quant::OpSet; +using ::tensorflow::quantization::OpSet; // Applies prepare quantization on the model in TF dialect for dynamic range // quantization case. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc index 765929a75043aa..3f54fe580fe1c4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc @@ -14,25 +14,36 @@ limitations under the License. ==============================================================================*/ // This transformation pass applies quantization propagation on TF dialect. -#include +#include #include -#include #include -#include +#include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" //===----------------------------------------------------------------------===// // The preprocess-op Pass. @@ -46,6 +57,7 @@ using QuantMethod = ::tensorflow::quantization::QuantizationMethod::PresetMethod; using QuantizationUnit = std::pair; using QuantizationUnits = llvm::SetVector; +using ::tensorflow::quantization::OpSet; // Preprocesses ops to allow multi-axis quantization, prior to quantization // passes. Currently, per-channel quantization only supports 1D results. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc index ca088c5d318cf4..26e468556a36ab 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Copied and modified from -// //third_party/tensorflow/compiler/mlir/lite/transforms/quantize.cc -// This transformation pass applies quantization on TF dialect. #include #include #include @@ -44,7 +41,6 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" @@ -60,6 +56,8 @@ namespace quant { //===----------------------------------------------------------------------===// namespace { +using ::tensorflow::quantization::OpSet; + enum QuantizationTrait { kFullQuantization, kDynamicRangeQuantization }; // Base struct for quantization. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc index 2ddb9f50eedee4..0b3c89c56f60bb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc @@ -62,6 +62,7 @@ namespace quant { namespace { using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; +using ::tensorflow::quantization::OpSet; constexpr absl::string_view kQuantizeCompositeFunctionsStepName = "_quantize_composite_functions"; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index eb91ce68063308..a7a56a610bec41 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -32,15 +32,20 @@ cc_library( "//tensorflow/python:__pkg__", ], deps = [ + ":py_function_lib", ":unfreeze_constants", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:context", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:post_calibration", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pre_calibration", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:saved_model_export", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:saved_model_import", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:types", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:statistics", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -55,7 +60,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", - "//tensorflow/core/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -83,6 +87,7 @@ cc_library( hdrs = ["quantize_model.h"], compatible_with = get_compatible_with_portable(), deps = [ + ":py_function_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "@com_google_absl//absl/container:flat_hash_map", @@ -169,11 +174,9 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", - "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:flat_hash_map", - #"@com_google_absl//absl/strings:string_view", - "@pybind11", + "@com_google_absl//absl/strings:string_view", ], ) @@ -232,9 +235,6 @@ tf_python_pybind_extension( ":py_function_lib", ":quantize_model_cc", ":type_casters", - "//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger", - "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", - "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:statistics", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index a28d7ebe4bf7f3..ec86deac1b497d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -83,7 +83,6 @@ 'UniformQuantizedDotHybrid', ) -_DebuggerOptions = quant_opts_pb2.DebuggerOptions _DebuggerConfig = stablehlo_quant_config_pb2.DebuggerConfig # Lists of ops whose channel dimension should be changed if per_channel @@ -1179,8 +1178,12 @@ def test_qat_gather_and_conv_model( quantization_options, ) self.assertIsNotNone(converted_model) - self.assertSizeRatioLessThan( - self._output_saved_model_path, self._input_saved_model_path, 0.5 + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.5, ) def test_qat_vocab_table_lookup_model(self): @@ -2017,15 +2020,22 @@ def test_gather_and_conv_model( output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: - self.assertSizeRatioGreaterThan( - self._output_saved_model_path, self._input_saved_model_path, 0.68 + self.assertGreater( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.68, ) self.assertTrue( self._contains_op(output_graphdef, 'UniformQuantizedConvolution') ) else: - self.assertSizeRatioLessThan( - self._output_saved_model_path, self._input_saved_model_path, 1 / 3 + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 1 / 3, ) if target_opset == quant_opts_pb2.XLA: self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) @@ -2976,12 +2986,19 @@ def test_gather_model( ) if expect_quantized_gather: - self.assertSizeRatioLessThan( - self._output_saved_model_path, self._input_saved_model_path, 1 / 3 + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 1 / 3, ) else: - self.assertSizeRatioGreaterThan( - self._output_saved_model_path, self._input_saved_model_path, 2 / 3 + self.assertGreater( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 2 / 3, ) @test_util.run_in_graph_and_eager_modes @@ -3578,7 +3595,9 @@ def test_ptq_multiple_signatures_invalid_dataset_raises_value_error(self): for _ in range(8) ] - with self.assertRaisesRegex(ValueError, 'Invalid representative dataset.'): + with self.assertRaisesRegex( + Exception, 'Representative dataset is not a mapping' + ): quantize_model.quantize( self._input_saved_model_path, output_directory=self._output_saved_model_path, @@ -3933,8 +3952,8 @@ def test_ptq_model_with_tf1_saved_model_invalid_input_key_raises_value_error( ) with self.assertRaisesRegex( - ValueError, - 'Failed to run graph for post-training quantization calibration', + Exception, + 'Invalid input keys for representative sample.', ): quantize_model.quantize( self._input_saved_model_path, @@ -4877,12 +4896,19 @@ def test_gather_model( ) if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: - self.assertSizeRatioGreaterThan( - self._output_saved_model_path, self._input_saved_model_path, 0.65 + self.assertGreater( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.65, ) else: - self.assertSizeRatioLessThan( - self._output_saved_model_path, self._input_saved_model_path, 1 / 3 + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 1 / 3, ) @parameterized.named_parameters( @@ -4931,8 +4957,11 @@ def test_gather_and_conv_model( output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: - self.assertSizeRatioGreaterThan( - self._output_saved_model_path, self._input_saved_model_path, 0.65 + self.assertGreater( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.65, ) self.assertTrue( self._contains_op( @@ -4940,8 +4969,12 @@ def test_gather_and_conv_model( ) ) else: - self.assertSizeRatioLessThan( - self._output_saved_model_path, self._input_saved_model_path, 1 / 3 + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 1 / 3, ) if target_opset == quant_opts_pb2.XLA: self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) @@ -5097,14 +5130,20 @@ def test_gather_model_tf1( if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: threshold = 0.45 if use_variable else 0.7 - self.assertSizeRatioGreaterThan( - self._output_saved_model_path, self._input_saved_model_path, threshold + self.assertGreater( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + threshold, ) else: threshold = 0.19 if use_variable else 0.42 - self.assertSizeRatioLessThan( - self._output_saved_model_path, self._input_saved_model_path, threshold + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + threshold, ) @test_util.run_in_graph_and_eager_modes @@ -5358,10 +5397,11 @@ def test_einsum_model( ) ) # Due to other meta data, the compression is not exactly 1/4. - self.assertSizeRatioLessThan( - self._output_saved_model_path, - self._input_saved_model_path, - threshold=0.5, + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.5, ) @parameterized.named_parameters( @@ -5409,10 +5449,11 @@ def test_matmul_model( # Due to other meta data, the compression is not exactly 1/4. self.assertTrue(self._contains_op(output_graphdef, 'XlaDotV2')) - self.assertSizeRatioLessThan( - self._output_saved_model_path, - self._input_saved_model_path, - threshold=0.3, + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.3, ) @parameterized.named_parameters( @@ -5469,10 +5510,11 @@ def test_conv_model( output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def # Due to other meta data, the compression is not exactly 1/4. - self.assertSizeRatioLessThan( - self._output_saved_model_path, - self._input_saved_model_path, - threshold=0.3, + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.3, ) if enable_per_channel_quantization: @@ -5561,10 +5603,11 @@ def test_depthwise_conv2d_model( # Due to other meta data, the compression is not exactly 1/4. size_threshold = 0.5 if enable_per_channel_quantization else 0.32 - self.assertSizeRatioLessThan( - self._output_saved_model_path, - self._input_saved_model_path, - threshold=size_threshold, + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + size_threshold, ) if enable_per_channel_quantization: @@ -5659,8 +5702,12 @@ def test_gather_model( self.assertCountEqual( converted_model.signatures._signatures.keys(), {'serving_default'} ) - self.assertSizeRatioLessThan( - self._output_saved_model_path, self._input_saved_model_path, 0.3 + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.3, ) @parameterized.named_parameters( @@ -5720,8 +5767,12 @@ def test_gather_and_conv_model( ) output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) - self.assertSizeRatioLessThan( - self._output_saved_model_path, self._input_saved_model_path, 1 / 3 + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 1 / 3, ) @test_util.run_in_graph_and_eager_modes @@ -5926,7 +5977,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 ), op_set=quant_opts_pb2.XLA, - debugger_options=_DebuggerOptions( + debugger_config=_DebuggerConfig( debugger_type=_DebuggerConfig.DebuggerType.DEBUGGER_TYPE_WHOLE_MODEL, unquantized_dump_model_path=unquantized_dump_model_path, log_dir_path=log_dir_path, @@ -6039,7 +6090,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 ), op_set=target_opset, - debugger_options=_DebuggerOptions( + debugger_config=_DebuggerConfig( debugger_type=debugger_type, log_dir_path=log_dir_path, ), @@ -6880,8 +6931,11 @@ def test_selective_quantization_on_gather( # The Conv2D op shouldn't be quantized as it has no FakeQuant on input. self.assertTrue(self._contains_op(graphdef, 'Conv2D')) # If the Gather op is quantized, input_model_size / output_model_size > 2. - self.assertSizeRatioLessThan( - self._input_saved_model_path, self._output_saved_model_path, 1.15 + self.assertLess( + testing.get_size_ratio( + self._input_saved_model_path, self._output_saved_model_path + ), + 1.15, ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py index 1aad9b619b61d6..245240e5ebb1be 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py @@ -108,40 +108,6 @@ def _any_log_contains( ) ) - def assertSizeRatioGreaterThan( - self, path_a: str, path_b: str, threshold: float - ): - """Check if the size ratio of the given paths is greater than the threshold. - - Args: - path_a: Path of a directory or a file to be the nominator of the ratio. - path_b: Path of a directory or a file to be the denominator of the ratio. - threshold: a number to compare with. - - Returns: - True if the size ratio of path_a / path_b is greater than threshold. - """ - size_a = self._get_dir_size(path_a) - size_b = self._get_dir_size(path_b) - size_ratio = size_a / size_b - return self.assertGreater(size_ratio, threshold) - - def assertSizeRatioLessThan(self, path_a: str, path_b: str, threshold: float): - """Check if the size ratio of the given paths is less than the threshold. - - Args: - path_a: Path of a directory or a file to be the nominator of the ratio. - path_b: Path of a directory or a file to be the denominator of the ratio. - threshold: a number to compare with. - - Returns: - True if the size ratio of path_a / path_b is less than threshold. - """ - size_a = self._get_dir_size(path_a) - size_b = self._get_dir_size(path_b) - size_ratio = size_a / size_b - return self.assertLess(size_ratio, threshold) - def _is_quantized_function(self, func: function_pb2.FunctionDef) -> bool: """Determine whether a FunctionDef is quantized. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 8273279df67787..a0865c44664290 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -30,9 +30,6 @@ limitations under the License. #include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" @@ -45,17 +42,13 @@ namespace py = pybind11; namespace { -using ::stablehlo::quantization::AddCalibrationStatistics; -using ::stablehlo::quantization::EnableDebugging; -using ::stablehlo::quantization::io::CreateTmpDir; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; using ::tensorflow::quantization::QuantizationOptions; -using ::tensorflow::quantization::QuantizePtqDynamicRange; -using ::tensorflow::quantization::QuantizePtqModelPostCalibration; -using ::tensorflow::quantization::QuantizePtqModelPreCalibration; +using ::tensorflow::quantization::QuantizeDynamicRangePtq; using ::tensorflow::quantization::QuantizeQatModel; +using ::tensorflow::quantization::QuantizeStaticRangePtq; using ::tensorflow::quantization::QuantizeWeightOnly; using ::tensorflow::quantization::RepresentativeDatasetFile; @@ -89,7 +82,7 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { // Remove the `tpu` tag from the debug quantized saved model as it is // for CPU. Note the 'tpu' value should be the same as `TPU` defined in // tensorflow/python/saved_model/tag_constants.py. - if (quantization_options.has_debugger_options()) { + if (quantization_options.has_debugger_config()) { tags.erase("tpu"); } py_function_library.SaveExportedModel( @@ -132,13 +125,13 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { quantization_options.tags().end()); const absl::StatusOr exported_model = - QuantizePtqDynamicRange(src_saved_model_path, signature_keys, tags, + QuantizeDynamicRangePtq(src_saved_model_path, signature_keys, tags, quantization_options); // Remove the `tpu` tag from the debug quantized saved model as it is // for CPU. Note the 'tpu' value should be the same as `TPU` defined in // tensorflow/python/saved_model/tag_constants.py. - if (quantization_options.has_debugger_options()) { + if (quantization_options.has_debugger_config()) { tags.erase("tpu"); } py_function_library.SaveExportedModel( @@ -222,73 +215,22 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { std::unordered_set tags; tags.insert(quantization_options.tags().begin(), quantization_options.tags().end()); - - absl::StatusOr exported_model = - QuantizePtqModelPreCalibration(src_saved_model_path, signature_keys, - tags, quantization_options); + const absl::StatusOr exported_model = + QuantizeStaticRangePtq(src_saved_model_path, signature_keys, tags, + quantization_options, signature_def_map, + py_function_library, + representative_dataset_file_map_serialized); if (!exported_model.ok()) return exported_model.status(); - const absl::StatusOr precalibrated_saved_model_dir = - CreateTmpDir(); - if (!precalibrated_saved_model_dir.ok()) { - throw py::value_error( - precalibrated_saved_model_dir.status().ToString()); - } - - py_function_library.SaveExportedModel( - *precalibrated_saved_model_dir, *exported_model, - src_saved_model_path, tags, signature_def_map); - - py_function_library.RunCalibration( - *precalibrated_saved_model_dir, signature_keys, tags, - quantization_options.calibration_options(), - quantization_options.force_graph_mode_calibration(), - representative_dataset_file_map_serialized); - - if (absl::Status status = AddCalibrationStatistics( - *exported_model->mutable_graph_def(), - quantization_options.calibration_options(), - py_function_library); - !status.ok()) { - LOG(WARNING) << "Some CustomAggregator ops do not have min or max " - "values. Parts of the graph are not quantized. " - << status; - } - - if (quantization_options.has_debugger_options()) { - EnableDebugging(*exported_model, - quantization_options.debugger_options(), - py_function_library, src_saved_model_path, tags, - signature_def_map); - } - - const absl::StatusOr calibrated_saved_model_path = - CreateTmpDir(); - if (!calibrated_saved_model_path.ok()) { - throw py::value_error( - calibrated_saved_model_path.status().ToString()); - } - - py_function_library.SaveExportedModel( - *calibrated_saved_model_path, *exported_model, src_saved_model_path, - tags, signature_def_map); - - const absl::StatusOr post_calibrated_exported_model = - QuantizePtqModelPostCalibration(*calibrated_saved_model_path, - signature_keys, tags, - quantization_options); - if (!post_calibrated_exported_model.ok()) - return post_calibrated_exported_model.status(); - // Remove the `tpu` tag from the debug quantized saved model as it is // for CPU. Note the 'tpu' value should be the same as `TPU` defined in // tensorflow/python/saved_model/tag_constants.py. - if (quantization_options.has_debugger_options()) { + if (quantization_options.has_debugger_config()) { tags.erase("tpu"); } py_function_library.SaveExportedModel( - dst_saved_model_path, *post_calibrated_exported_model, - *calibrated_saved_model_path, tags, signature_def_map); + dst_saved_model_path, *exported_model, src_saved_model_path, tags, + signature_def_map); return absl::OkStatus(); }, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 08b71190bbb5b5..89467d30944ca9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -34,18 +34,23 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" @@ -54,7 +59,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" #include "tsl/platform/errors.h" @@ -64,7 +68,6 @@ namespace tensorflow { namespace quantization { namespace { -using ::mlir::quant::stablehlo::AddExportPasses; using ::mlir::quant::stablehlo::ConvertMlirModuleToExportedModel; using ::mlir::quant::stablehlo::CreateMlirContextForQuantization; using ::mlir::quant::stablehlo::ExportOptions; @@ -75,21 +78,17 @@ using ::mlir::quant::stablehlo::kExportStepSuffix; using ::mlir::quant::stablehlo::PostCalibrationComponent; using ::mlir::quant::stablehlo::PreCalibrationComponent; using ::mlir::quant::stablehlo::UpdateFunctionAliases; +using ::stablehlo::quantization::AddCalibrationStatistics; +using ::stablehlo::quantization::ChangeToQuantizedFilename; using ::stablehlo::quantization::DebuggerConfig; +using ::stablehlo::quantization::DisableDebugging; +using ::stablehlo::quantization::EnableDebugging; +using ::stablehlo::quantization::ExpandPresets; +using ::stablehlo::quantization::PopulateDefaults; using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::io::CreateTmpDir; using ::stablehlo::quantization::io::GetLocalTmpFileName; - -// TODO: b/326355110 - Removes `ConvertDebuggerOptionToDebuggerConfig` when -// merging `DebuggingOption` to `DebuggingConfig`. -DebuggerConfig ConvertDebuggerOptionToDebuggerConfig( - const DebuggerOptions &debugger_options) { - DebuggerConfig debugger_config; - debugger_config.set_debugger_type(debugger_options.debugger_type()); - debugger_config.set_unquantized_dump_model_path( - debugger_options.unquantized_dump_model_path()); - debugger_config.set_log_dir_path(debugger_options.log_dir_path()); - return debugger_config; -} +using ::tensorflow::quantization::PyFunctionLibrary; absl::StatusOr> ImportAndPreprocessSavedModel( absl::string_view saved_model_path, @@ -135,41 +134,6 @@ absl::StatusOr> ImportAndPreprocessSavedModel( return module_ref; } -// Sets up and runs the passes for exporting `module_op`. The behavior of the -// exporting passes is controlled by `export_opts`. Returns `AssetFileDef`s that -// associate the input arguments of @main and the asset file names. Asset file -// names will be used to feed the corresponding tensors during initialization -// upon model loading. -absl::StatusOr> RunExportPasses( - const ExportOptions &export_opts, mlir::MLIRContext &ctx, - mlir::ModuleOp module_op) { - if (export_opts.unfreeze_constants) { - TF_RETURN_IF_ERROR(UnfreezeConstantsAndSaveVariables( - export_opts.checkpoint_dir, ctx, module_op)); - LOG(INFO) << "Unfrozen constants and saved variables to checkpoint file: " - << export_opts.checkpoint_dir; - } - - if (absl::Status pass_run_status = RunPasses( - /*name=*/ - export_opts.debug_name, - /*add_passes_func=*/ - [dup_constants = export_opts.duplicate_shape_determining_constants]( - mlir::PassManager &pm) { AddExportPasses(pm, dup_constants); }, - ctx, module_op); - !pass_run_status.ok()) { - return pass_run_status; - } - - mlir::FailureOr> asset_file_defs = - mlir::quant::ConvertAssetArgs(module_op); - if (failed(asset_file_defs)) { - return absl::InternalError("Failed to convert asset args."); - } - - return *asset_file_defs; -} - absl::StatusOr ModuleOpToExportedModel( mlir::ModuleOp module_op, mlir::MLIRContext *ctx, absl::string_view step_name, const bool unfreeze_constants, @@ -189,90 +153,63 @@ absl::StatusOr ModuleOpToExportedModel( {asset_file_defs.begin(), asset_file_defs.end()}); } -} // namespace - -absl::StatusOr QuantizeQatModel( - const absl::string_view saved_model_path, - const std::vector &signature_keys, - const std::unordered_set &tags, - const QuantizationOptions &quantization_options) { - // Convert the SavedModelBundle to an MLIR module. - std::unique_ptr context = - CreateMlirContextForQuantization(); +absl::StatusOr ExportCalibrationModel( + mlir::ModuleOp module_op, mlir::MLIRContext *context, + const QuantizationOptions &quantization_options, + const absl::flat_hash_map &function_aliases) { + // Clone ModuleOp and function aliases so changes in this pipeline won't + // be reflected in the original values. + mlir::OwningOpRef cloned_module_ref(module_op.clone()); - absl::StatusOr> - function_aliases = GetFunctionAliases(saved_model_path, tags); - if (!function_aliases.ok()) { - return absl::InternalError(absl::StrCat( - "Failed to get function alias: ", function_aliases.status().message())); - } + // Disable DumpTensor ops when running calibration. + DisableDebugging(*cloned_module_ref); - absl::StatusOr> module = - ImportAndPreprocessSavedModel( - saved_model_path, signature_keys, tags, context.get(), - /*is_inliner_run=*/true, - /*run_tf_to_stablehlo=*/false, - /*deserialize_xla_call_module=*/false, *function_aliases); - if (!module.status().ok()) { + absl::StatusOr exported_model = ModuleOpToExportedModel( + *cloned_module_ref, context, kTfQuantPtqPreCalibrationStepName, + /*unfreeze_constants=*/!quantization_options.freeze_all_variables(), + function_aliases); + if (!exported_model.status().ok()) { return absl::InternalError( - absl::StrCat("Failed to import and preprocess SavedModel: ", - module.status().message())); + absl::StrCat("Failed to export calibration model: ", + exported_model.status().message())); } - mlir::OwningOpRef module_ref = std::move(module).value(); - - TF_RETURN_IF_ERROR(RunPasses( - /*name=*/ - kTfQuantQatStepName, /*add_passes_func=*/ - [&quantization_options](mlir::PassManager &pm) { - AddQuantizeQatPasses(pm, quantization_options, kTfQuantQatStepName); - }, - *context, *module_ref)); - return ModuleOpToExportedModel( - *module_ref, context.get(), kTfQuantQatStepName, - /*unfreeze_constants=*/!quantization_options.freeze_all_variables(), - *function_aliases); + return *exported_model; } -absl::StatusOr QuantizePtqModelPreCalibration( - const absl::string_view saved_model_path, - const std::vector &signature_keys, - const std::unordered_set &tags, +QuantizationConfig GetQuantizationConfigForStaticRangePtq( const QuantizationOptions &quantization_options) { - // Convert the SavedModelBundle to an MLIR module. - std::unique_ptr context = - CreateMlirContextForQuantization(); - - absl::StatusOr> - function_aliases = GetFunctionAliases(saved_model_path, tags); - if (!function_aliases.ok()) { - return absl::InternalError(absl::StrCat( - "Failed to get function alias: ", function_aliases.status().message())); - } + QuantizationConfig quantization_config{}; + // TODO: b/331302857 - Remove `enable_per_channel_quantized_weight` usage. + quantization_config.mutable_static_range_ptq_preset() + ->set_enable_per_channel_quantized_weight( + quantization_options.enable_per_channel_quantization()); + // When targeting server TPUs quantized types should be unpacked into + // integer ops. + quantization_config.mutable_pipeline_config()->set_unpack_quantized_types( + true); + *quantization_config.mutable_debugger_config() = + quantization_options.debugger_config(); + quantization_config.mutable_static_range_ptq_preset(); + *quantization_config.mutable_calibration_options() = + quantization_options.calibration_options(); + + return ExpandPresets(PopulateDefaults(quantization_config)); +} +absl::StatusOr QuantizePtqModelPreCalibrationImpl( + mlir::ModuleOp module_op, mlir::MLIRContext *context, + const QuantizationOptions &quantization_options, + const absl::flat_hash_map &function_aliases) { const bool is_stablehlo = quantization_options.op_set() == OpSet::STABLEHLO; - absl::StatusOr> module = - ImportAndPreprocessSavedModel( - saved_model_path, signature_keys, tags, context.get(), - /*is_inliner_run=*/true, - /*run_tf_to_stablehlo=*/is_stablehlo, - /*deserialize_xla_call_module=*/false, *function_aliases); - if (!module.status().ok()) { - return absl::InternalError( - absl::StrCat("Failed to import and preprocess SavedModel: ", - module.status().message())); - } - mlir::OwningOpRef module_ref = std::move(module).value(); - // Use StableHLO Quantizer option if opset is specified. if (is_stablehlo) { - QuantizationConfig quantization_config; - *quantization_config.mutable_debugger_config() = - ConvertDebuggerOptionToDebuggerConfig( - quantization_options.debugger_options()); - PreCalibrationComponent pre_calibration_component(context.get()); - TF_ASSIGN_OR_RETURN(*module_ref, pre_calibration_component.Run( - *module_ref, quantization_config)); + const QuantizationConfig quantization_config = + GetQuantizationConfigForStaticRangePtq(quantization_options); + + PreCalibrationComponent pre_calibration_component(context); + TF_ASSIGN_OR_RETURN(module_op, pre_calibration_component.Run( + module_op, quantization_config)); } else { TF_RETURN_IF_ERROR(RunPasses( /*name=*/ @@ -280,17 +217,47 @@ absl::StatusOr QuantizePtqModelPreCalibration( [&quantization_options](mlir::PassManager &pm) { AddQuantizePtqPreCalibrationPasses(pm, quantization_options); }, - *context, *module_ref)); + *context, module_op)); + } + + return ExportCalibrationModel(module_op, context, quantization_options, + function_aliases); +} + +absl::StatusOr QuantizePtqModelPostCalibrationImpl( + mlir::ModuleOp module_op, mlir::MLIRContext *context, + const QuantizationOptions &quantization_options, + const absl::flat_hash_map &function_aliases) { + const bool is_stablehlo = quantization_options.op_set() == OpSet::STABLEHLO; + // Use StableHLO Quantizer option if opset is specified. + if (is_stablehlo) { + const QuantizationConfig quantization_config = + GetQuantizationConfigForStaticRangePtq(quantization_options); + + PostCalibrationComponent post_calibration_component(context); + TF_ASSIGN_OR_RETURN(module_op, post_calibration_component.Run( + module_op, quantization_config)); + } else { + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/ + kTfQuantPtqPostCalibrationStepName, /*add_passes_func=*/ + [&quantization_options](mlir::PassManager &pm) { + AddQuantizePtqPostCalibrationPasses( + pm, quantization_options, kTfQuantPtqPostCalibrationStepName); + }, + *context, module_op)); } return ModuleOpToExportedModel( - *module_ref, context.get(), kTfQuantPtqPreCalibrationStepName, + module_op, context, kTfQuantPtqPostCalibrationStepName, /*unfreeze_constants=*/!quantization_options.freeze_all_variables(), - *function_aliases); + function_aliases); } -absl::StatusOr QuantizePtqModelPostCalibration( - const absl::string_view saved_model_path, +} // namespace + +absl::StatusOr QuantizeQatModel( + absl::string_view saved_model_path, const std::vector &signature_keys, const std::unordered_set &tags, const QuantizationOptions &quantization_options) { @@ -304,16 +271,12 @@ absl::StatusOr QuantizePtqModelPostCalibration( "Failed to get function alias: ", function_aliases.status().message())); } - const bool is_stablehlo = quantization_options.op_set() == OpSet::STABLEHLO; - // Freezing is required again since variables might have been produced during - // the pre-calibration step. `is_inliner_run = false` to prevent the functions - // lifted for quantization from being inlined. absl::StatusOr> module = ImportAndPreprocessSavedModel( saved_model_path, signature_keys, tags, context.get(), - /*is_inliner_run=*/false, + /*is_inliner_run=*/true, /*run_tf_to_stablehlo=*/false, - /*deserialize_xla_call_module=*/is_stablehlo, *function_aliases); + /*deserialize_xla_call_module=*/false, *function_aliases); if (!module.status().ok()) { return absl::InternalError( absl::StrCat("Failed to import and preprocess SavedModel: ", @@ -321,39 +284,22 @@ absl::StatusOr QuantizePtqModelPostCalibration( } mlir::OwningOpRef module_ref = std::move(module).value(); - // Use StableHLO Quantizer option if opset is specified. - if (is_stablehlo) { - QuantizationConfig quantization_config{}; - quantization_config.mutable_static_range_ptq_preset() - ->set_enable_per_channel_quantized_weight( - quantization_options.enable_per_channel_quantization()); - // When targeting server TPUs quantized types should be unpacked into - // integer ops. - quantization_config.mutable_pipeline_config()->set_unpack_quantized_types( - true); - - PostCalibrationComponent post_calibration_component(context.get()); - TF_ASSIGN_OR_RETURN(*module_ref, post_calibration_component.Run( - *module_ref, quantization_config)); - } else { - TF_RETURN_IF_ERROR(RunPasses( - /*name=*/ - kTfQuantPtqPostCalibrationStepName, /*add_passes_func=*/ - [&quantization_options](mlir::PassManager &pm) { - AddQuantizePtqPostCalibrationPasses( - pm, quantization_options, kTfQuantPtqPostCalibrationStepName); - }, - *context, *module_ref)); - } + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/ + kTfQuantQatStepName, /*add_passes_func=*/ + [&quantization_options](mlir::PassManager &pm) { + AddQuantizeQatPasses(pm, quantization_options, kTfQuantQatStepName); + }, + *context, *module_ref)); return ModuleOpToExportedModel( - *module_ref, context.get(), kTfQuantPtqPostCalibrationStepName, + *module_ref, context.get(), kTfQuantQatStepName, /*unfreeze_constants=*/!quantization_options.freeze_all_variables(), *function_aliases); } -absl::StatusOr QuantizePtqDynamicRange( - const absl::string_view saved_model_path, +absl::StatusOr QuantizeDynamicRangePtq( + absl::string_view saved_model_path, const std::vector &signature_keys, const std::unordered_set &tags, const QuantizationOptions &quantization_options) { @@ -373,13 +319,11 @@ absl::StatusOr QuantizePtqDynamicRange( /*is_inliner_run=*/true, /*run_tf_to_stablehlo=*/false, /*deserialize_xla_call_module=*/false, *function_aliases); - if (!module.status().ok()) { return absl::InternalError( absl::StrCat("Failed to import and preprocess SavedModel: ", module.status().message())); } - mlir::OwningOpRef module_ref = std::move(module).value(); TF_RETURN_IF_ERROR(RunPasses( @@ -400,7 +344,7 @@ absl::StatusOr QuantizePtqDynamicRange( // TODO: b/297626257 - [Converter Component][TF-Quantizer] Clean up // quantize_model.cc by factoring out repeated codes absl::StatusOr QuantizeWeightOnly( - const absl::string_view saved_model_path, + absl::string_view saved_model_path, const QuantizationOptions &quantization_options) { std::unique_ptr context = CreateMlirContextForQuantization(); @@ -423,13 +367,11 @@ absl::StatusOr QuantizeWeightOnly( quantization_options.tags().end()}, context.get(), /*is_inliner_run=*/true, /*run_tf_to_stablehlo=*/false, /*deserialize_xla_call_module=*/false, *function_aliases); - if (!module.status().ok()) { return absl::InternalError( absl::StrCat("Failed to import and preprocess SavedModel: ", module.status().message())); } - mlir::OwningOpRef module_ref = std::move(module).value(); TF_RETURN_IF_ERROR(RunPasses( @@ -447,5 +389,90 @@ absl::StatusOr QuantizeWeightOnly( *function_aliases); } +absl::StatusOr QuantizeStaticRangePtq( + absl::string_view saved_model_path, + const std::vector &signature_keys, + const std::unordered_set &tags, + const QuantizationOptions &quantization_options, + const absl::flat_hash_map &signature_def_map, + const PyFunctionLibrary &py_function_library, + const absl::flat_hash_map + &representative_dataset_file_map_serialized) { + std::unique_ptr context = + CreateMlirContextForQuantization(); + + absl::StatusOr> + function_aliases = GetFunctionAliases(saved_model_path, tags); + if (!function_aliases.ok()) { + return absl::InternalError(absl::StrCat( + "Failed to get function alias: ", function_aliases.status().message())); + } + + const bool is_stablehlo = quantization_options.op_set() == OpSet::STABLEHLO; + absl::StatusOr> module = + ImportAndPreprocessSavedModel( + saved_model_path, signature_keys, tags, context.get(), + /*is_inliner_run=*/true, + /*run_tf_to_stablehlo=*/is_stablehlo, + /*deserialize_xla_call_module=*/false, *function_aliases); + if (!module.status().ok()) { + return absl::InternalError( + absl::StrCat("Failed to import and preprocess SavedModel: ", + module.status().message())); + } + mlir::OwningOpRef module_ref = std::move(module).value(); + + TF_ASSIGN_OR_RETURN( + absl::StatusOr pre_calibration_exported_model, + QuantizePtqModelPreCalibrationImpl( + *module_ref, context.get(), quantization_options, *function_aliases)); + + TF_ASSIGN_OR_RETURN( + const absl::StatusOr precalibrated_saved_model_dir, + CreateTmpDir()); + + py_function_library.SaveExportedModel( + *precalibrated_saved_model_dir, *pre_calibration_exported_model, + saved_model_path, tags, signature_def_map); + + py_function_library.RunCalibration( + *precalibrated_saved_model_dir, signature_keys, tags, + quantization_options.calibration_options(), + quantization_options.force_graph_mode_calibration(), + representative_dataset_file_map_serialized); + + if (absl::Status status = AddCalibrationStatistics( + *module_ref, quantization_options.calibration_options(), + py_function_library); + !status.ok()) { + LOG(WARNING) << "Some CustomAggregator ops do not have min or max " + "values. Parts of the graph are not quantized. " + << status; + } + + // Saves the current model to the `unquantized_dump_model_path` if the + // debugger type is `DEBUGGER_TYPE_WHOLE_MODEL`. This is required + // because in whole-model debugging mode the `DumpTensor` ops for the + // unquantized tensors are only inserted in the unquantized model + // whereas `DumpTensor` ops for the quantized tensors are only inserted + // in the quantized model. Both models are required to be able to dump + // both quantized and unquantized tensors and compare them offline. + if (quantization_options.has_debugger_config() && + quantization_options.debugger_config().debugger_type() == + DebuggerConfig::DEBUGGER_TYPE_WHOLE_MODEL) { + EnableDebugging(*pre_calibration_exported_model); + ChangeToQuantizedFilename(*module_ref); + + absl::string_view unquantized_dump_model_path = + quantization_options.debugger_config().unquantized_dump_model_path(); + py_function_library.SaveExportedModel( + unquantized_dump_model_path, *pre_calibration_exported_model, + saved_model_path, tags, signature_def_map); + } + + return QuantizePtqModelPostCalibrationImpl( + *module_ref, context.get(), quantization_options, *function_aliases); +} + } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h index 556086ce018123..a54e988c043aa3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" namespace tensorflow { @@ -46,28 +47,28 @@ absl::StatusOr QuantizeQatModel( const std::unordered_set& tags, const QuantizationOptions& quantization_options); -// Apply post-training dynamic range quantization to the model. -absl::StatusOr QuantizePtqDynamicRange( +// Applies post-training dynamic-range quantization to the model. +absl::StatusOr QuantizeDynamicRangePtq( absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, const QuantizationOptions& quantization_options); +// Applies post-training static-range weight-only quantization to the model. absl::StatusOr QuantizeWeightOnly( absl::string_view saved_model_path, const QuantizationOptions& quantization_options); -absl::StatusOr QuantizePtqModelPreCalibration( +// Applies post-training static-range quantization to the model. +absl::StatusOr QuantizeStaticRangePtq( absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const QuantizationOptions& quantization_options); - -absl::StatusOr QuantizePtqModelPostCalibration( - absl::string_view saved_model_path, - const std::vector& signature_keys, - const std::unordered_set& tags, - const QuantizationOptions& quantization_options); + const QuantizationOptions& quantization_options, + const absl::flat_hash_map& signature_def_map, + const PyFunctionLibrary& py_function_library, + const absl::flat_hash_map& + representative_dataset_file_map_serialized); } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 1bf3fe81c7d8ba..e0eeca13d92f20 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -59,6 +59,8 @@ _QuantizationComponent = _QuantizationComponentSpec.QuantizationComponent _TensorType = _QuantizationComponentSpec.TensorType +_RepresentativeDatasetFile = quant_opts_pb2.RepresentativeDatasetFile + # Mapping of signature def key -> SignatureDef. _SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef] @@ -99,6 +101,57 @@ def _serialize_signature_def_map( return signature_def_map_serialized +def _save_representative_dataset( + representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, + signature_def_map: _SignatureDefMap, +) -> Mapping[str, _RepresentativeDatasetFile]: + """Saves the representative dataset to temporary TFRecord files. + + Args: + representative_dataset: Representative dataset used for the calibration + step. Representative datasets should exist for each signature def key in + `signature_def_keys`. + signature_def_map: Signature def key -> SignatureDef mapping. + + Returns: + A map from signature key to the saved representative dataset file. + """ + if isinstance(representative_dataset, Mapping): + if set(signature_def_map.keys()) != set(representative_dataset.keys()): + raise ValueError( + 'The signature keys and the keys of representative dataset map ' + f'do not match. Signature keys: {set(signature_def_map.keys())}, ' + f'representative dataset map: {set(representative_dataset.keys())}.' + ) + representative_dataset_map = representative_dataset + elif len(signature_def_map.keys()) > 1: + raise ValueError( + 'Representative dataset is not a mapping (got: ' + f'{type(representative_dataset)}), but there is more than one ' + 'signature key provided. Please provide a map of ' + '{signature_key -> dataset} with more than one signature key.' + ) + else: + representative_dataset_map = { + list(signature_def_map.keys())[0]: representative_dataset, + } + + # Save the representative dataset to temporary TFRecord files. + path_map = {} + expected_input_key_map = {} + for signature_key, signature_def in signature_def_map.items(): + # Filepath is the second return value of mkstemp. + _, path_map[signature_key] = tempfile.mkstemp( + suffix='.tfrecord', prefix=signature_key + ) + expected_input_key_map[signature_key] = signature_def.inputs.keys() + + return repr_dataset.TfRecordRepresentativeDatasetSaver( + path_map=path_map, + expected_input_key_map=expected_input_key_map, + ).save(representative_dataset_map) + + def _run_static_range_qat( src_saved_model_path: str, dst_saved_model_path: str, @@ -133,7 +186,7 @@ def _run_static_range_ptq( src_saved_model_path: str, dst_saved_model_path: str, quant_opts: _QuantizationOptions, - representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, + representative_dataset: Mapping[str, _RepresentativeDatasetFile], signature_def_map: _SignatureDefMap, ) -> None: """Runs static-range Post-Training Quantization. @@ -147,9 +200,8 @@ def _run_static_range_ptq( src_saved_model_path: Path to the source SavedModel directory. dst_saved_model_path: Path to the destination SavedModel directory. quant_opts: Quantization options. - representative_dataset: Representative dataset used for the calibration - step. Representative datasets should exist for each signature def key in - `signature_def_keys`. + representative_dataset: A map from signature key to the saved representative + dataset file. signature_def_map: Signature def key -> SignatureDef mapping. Raises: @@ -159,29 +211,11 @@ def _run_static_range_ptq( signature_def_map_serialized = _serialize_signature_def_map(signature_def_map) - if isinstance(representative_dataset, Mapping): - representative_dataset_map = representative_dataset - else: - representative_dataset_map = { - list(signature_def_map.keys())[0]: representative_dataset, - } - - # Save the representative dataset to temporary TFRecord files. - path_map = {} - for signature_key in representative_dataset_map.keys(): - path_map[signature_key] = tempfile.mkstemp( - suffix='.tfrecord', prefix=signature_key - )[1] # Filepath. - - dataset_file_map = repr_dataset.TfRecordRepresentativeDatasetSaver( - path_map - ).save(representative_dataset_map) - # `quantize_ptq_static_range` requires `RepresentativeDatasetFile`s to be # serialized. Serialize the values to match the type. dataset_file_map_serialized = { signature_key: dataset_file.SerializeToString() - for signature_key, dataset_file in dataset_file_map.items() + for signature_key, dataset_file in representative_dataset.items() } pywrap_quantize_model.quantize_ptq_static_range( src_saved_model_path, @@ -246,9 +280,24 @@ def _static_range_quantize( set(quantization_options.tags), ) + if ( + representative_dataset is not None + and quantization_options.representative_datasets + ): + raise ValueError( + 'Do not specify both the `representative_dataset` argument and' + ' the `representative_datasets` field in `QuantizationOptions`.' + ) + + saved_representative_dataset = quantization_options.representative_datasets + if representative_dataset is not None: + saved_representative_dataset = _save_representative_dataset( + representative_dataset, signature_def_map + ) + # Checks if the model is from QAT or method is METHOD_NO_QUANTIZE. if ( - representative_dataset is None + not saved_representative_dataset and not is_qat_saved_model_or_method_no_quantize ): raise ValueError( @@ -274,7 +323,7 @@ def _static_range_quantize( src_saved_model_path, dst_saved_model_path, quantization_options, - representative_dataset, + saved_representative_dataset, signature_def_map, ) @@ -692,7 +741,7 @@ def _populate_quantization_options_default_values( ' quantization via TF Quantizer.' ) - if quantization_options.HasField('debugger_options'): + if quantization_options.HasField('debugger_config'): # Set `force_graph_mode_calibration` to True to avoid skipping op execution, # which are not connected to return ops, during calibration execution. # Setting `force_graph_mode_calibration` to True enables execution of the @@ -704,11 +753,11 @@ def _populate_quantization_options_default_values( ) quantization_options.force_graph_mode_calibration = True - if not quantization_options.debugger_options.log_dir_path: - quantization_options.debugger_options.log_dir_path = '/tmp/dumps' + if not quantization_options.debugger_config.log_dir_path: + quantization_options.debugger_config.log_dir_path = '/tmp/dumps' if ( - quantization_options.debugger_options.debugger_type + quantization_options.debugger_config.debugger_type == stablehlo_quant_config_pb2.DebuggerConfig.DebuggerType.DEBUGGER_TYPE_UNSPECIFIED ): raise ValueError( @@ -716,9 +765,9 @@ def _populate_quantization_options_default_values( ) if ( - quantization_options.debugger_options.debugger_type + quantization_options.debugger_config.debugger_type == stablehlo_quant_config_pb2.DebuggerConfig.DebuggerType.DEBUGGER_TYPE_WHOLE_MODEL - and not quantization_options.debugger_options.unquantized_dump_model_path + and not quantization_options.debugger_config.unquantized_dump_model_path ): raise ValueError( 'Debugger type whole model verify was used but' @@ -840,20 +889,6 @@ def quantize( _populate_quantization_options_default_values(quantization_options) - if ( - representative_dataset is not None - and quantization_options.representative_datasets - ): - raise ValueError( - 'Do not specify both the `representative_dataset` argument and' - ' the `representative_datasets` field in `QuantizationOptions`.' - ) - - if quantization_options.representative_datasets: - representative_dataset = repr_dataset.TfRecordRepresentativeDatasetLoader( - quantization_options.representative_datasets - ).load() - method: _QuantizationMethod = quantization_options.quantization_method if ( method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_INT8 diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py index fabda2ebad3397..c18358745866b4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py @@ -14,7 +14,7 @@ # ============================================================================== """Defines types required for representative datasets for quantization.""" -import collections.abc +from collections.abc import Collection, Sized import os from typing import Iterable, Mapping, Optional, Union @@ -117,7 +117,11 @@ class TfRecordRepresentativeDatasetSaver(RepresentativeDatasetSaver): ``` """ - def __init__(self, path_map: Mapping[str, os.PathLike[str]]): + def __init__( + self, + path_map: Mapping[str, os.PathLike[str]], + expected_input_key_map: Optional[Mapping[str, Collection[str]]] = None, + ): """Initializes TFRecord represenatative dataset saver. Args: @@ -125,8 +129,22 @@ def __init__(self, path_map: Mapping[str, os.PathLike[str]]): to which a `RepresentativeDataset` is saved. The signature def keys should be a subset of the `SignatureDef` keys of the `representative_dataset` argument of the `save()` call. + expected_input_key_map: Signature def key -> expected input keys. If set, + validate that the sample has same set of input keys before saving. + + Raises: + KeyError: If path_map and expected_input_key_map have different keys. """ self.path_map: Mapping[str, os.PathLike[str]] = path_map + self.expected_input_key_map: Mapping[str, Collection[str]] = {} + if expected_input_key_map is not None: + if set(path_map.keys()) != set(expected_input_key_map.keys()): + raise KeyError( + 'The `path_map` and `expected_input_key_map` should have the same' + ' set of keys.' + ) + + self.expected_input_key_map = expected_input_key_map def _save_tf_record_dataset( self, @@ -143,6 +161,10 @@ def _save_tf_record_dataset( Returns: a RepresentativeDatasetFile instance contains the path to the saved file. + + Raises: + KeyError: If the set of input keys in the dataset samples doesn't match + the set of expected input keys. """ # When running in graph mode (TF1), tf.Tensor types should be converted to # numpy ndarray types to be compatible with `make_tensor_proto`. @@ -150,9 +172,23 @@ def _save_tf_record_dataset( with session.Session() as sess: repr_ds = replace_tensors_by_numpy_ndarrays(repr_ds, sess) + expected_input_keys = self.expected_input_key_map.get( + signature_def_key, None + ) tfrecord_file_path = self.path_map[signature_def_key] with python_io.TFRecordWriter(tfrecord_file_path) as writer: for repr_sample in repr_ds: + if ( + expected_input_keys is not None + and set(repr_sample.keys()) != expected_input_keys + ): + raise KeyError( + 'Invalid input keys for representative sample. The function' + f' expects input keys of: {set(expected_input_keys)}. Got:' + f' {set(repr_sample.keys())}. Please provide correct input keys' + ' for representative samples.' + ) + sample = _RepresentativeDataSample() for input_name, input_value in repr_sample.items(): sample.tensor_proto_inputs[input_name].CopyFrom( @@ -317,7 +353,7 @@ def get_num_samples(repr_ds: RepresentativeDataset) -> Optional[int]: is malformed; it simply means the size cannot be determined without iterating the whole dataset. """ - if isinstance(repr_ds, collections.abc.Sized): + if isinstance(repr_ds, Sized): try: return len(repr_ds) except Exception as ex: # pylint: disable=broad-except diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto index 13d3876500fe0d..d2c79b6ce4c668 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto @@ -145,21 +145,6 @@ message RepresentativeDatasetFile { } } -// Configuration for quantization debugger. -// NEXT ID: 4 -message DebuggerOptions { - // Type of quantization debugger. Depending on the type, inputs and outputs - // are wired differently. - stablehlo.quantization.DebuggerConfig.DebuggerType debugger_type = 1; - - // Path to save unquantized model with dump tensor ops attached. - // Used when debugger_type is WHOLE_MODEL. - string unquantized_dump_model_path = 2; - - // Path to save debugger related logs. Defaults to '/tmp/dumps'. - string log_dir_path = 3; -} - // Defines various options to specify and control the behavior of the quantizer. // It consists of // 1) Model-wise quantization configuration as a default configuration. If it is @@ -251,7 +236,7 @@ message QuantizationOptions { stablehlo.quantization.CalibrationOptions calibration_options = 15; // Configuration related to quantization debugger. - DebuggerOptions debugger_options = 16; + stablehlo.quantization.DebuggerConfig debugger_config = 16; reserved 3; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc index 0d5e43cd6f334e..0e756021844a5c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc @@ -149,10 +149,10 @@ void AddQuantizePtqPreCalibrationPasses( pm.addPass(mlir::quant::CreateLiftQuantizableSpotsAsFunctionsPass( quantization_options)); // TODO: b/295140328 - Add debugger support for weight only - if (quantization_options.has_debugger_options()) { + if (quantization_options.has_debugger_config()) { pm.addPass(mlir::quant::CreateAddDumpTensorOpPass( - quantization_options.debugger_options().debugger_type(), - quantization_options.debugger_options().log_dir_path())); + quantization_options.debugger_config().debugger_type(), + quantization_options.debugger_config().log_dir_path())); } pm.addNestedPass( mlir::quant::CreateInsertCustomAggregationOpsPass( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir index fe12e5935a8791..d50f28941f4269 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir @@ -29,7 +29,7 @@ module { // WholeModel-DAG: %[[b:.*]] = "tf.Const"() <{value = dense<[-2.000000e+00, 3.000000e+00 // WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> // WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} -// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () // WholeModel-DAG: return %[[output0]], %[[output1]] // IntPerLayer-LABEL: func @conv @@ -38,8 +38,8 @@ module { // IntPerLayer-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} // IntPerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} // IntPerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %cst, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0} -// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () -// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () // IntPerLayer-DAG: return %[[output0]], %[[output1_quantized]] // FloatPerLayer-LABEL: func @conv @@ -48,8 +48,8 @@ module { // FloatPerLayer-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} // FloatPerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} // FloatPerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0} -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () // FloatPerLayer-DAG: return %[[output0]], %[[output1_unquantized]] } @@ -86,9 +86,9 @@ module { // WholeModel-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}0.193340182, 0.285152316 // WholeModel-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445 // WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> {_tfl_quant_trait = "fully_quantizable"} -// WholeModel-DAG: "tf.DumpTensor"(%[[output0]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> +// WholeModel-DAG: "tf.DumpTensor"(%[[output0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> // WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%[[output0]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} -// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> +// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> // WholeModel-DAG: return %[[output1]] // IntPerLayer-LABEL: func @multiple_conv2d @@ -98,12 +98,12 @@ module { // IntPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445 // IntPerLayer-DAG: %[[output0_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> {_tfl_quant_trait = "fully_quantizable"} // IntPerLayer-DAG: %[[output0_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_0}> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[output0_quantized]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[output0_unquantized]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output0_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output0_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> // IntPerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} // IntPerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0}> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> // IntPerLayer-DAG: return %[[output1_quantized]] // FloatPerLayer-LABEL: func @multiple_conv2d @@ -113,12 +113,12 @@ module { // FloatPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445 // FloatPerLayer-DAG: %[[output0_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> {_tfl_quant_trait = "fully_quantizable"} // FloatPerLayer-DAG: %[[output0_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_0} -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output0_quantized]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"} -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output0_unquantized]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output0_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output0_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"} // FloatPerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%[[output0_unquantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} // FloatPerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%[[output0_unquantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0} -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"} -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"} // FloatPerLayer-DAG: return %[[output1_unquantized]] } @@ -146,8 +146,8 @@ module { // WholeModel-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893 // WholeModel-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} // WholeModel-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} -// WholeModel-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} -// WholeModel-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// WholeModel-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// WholeModel-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} // WholeModel-DAG: return %[[m1]] // IntPerLayer-LABEL: func @matmul2 @@ -155,12 +155,12 @@ module { // IntPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893 // IntPerLayer-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // IntPerLayer-DAG: %[[m0_1:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () -// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () // IntPerLayer-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // IntPerLayer-DAG: %[[m1_0:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () -// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () // IntPerLayer-DAG: return %[[m1]] : tensor<2x2xf32> // FloatPerLayer-LABEL: func @matmul2 @@ -168,12 +168,12 @@ module { // FloatPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893 // FloatPerLayer-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // FloatPerLayer-DAG: %[[m0_1:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () // FloatPerLayer-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0_1]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // FloatPerLayer-DAG: %[[m1_0:.*]] = "tf.PartitionedCall"(%[[m0_1]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () // FloatPerLayer-DAG: return %[[m1_0]] : tensor<2x2xf32> } @@ -203,8 +203,8 @@ module { // WholeModel-DAG: %[[pc_0:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} // WholeModel-DAG: %[[sm_0:.*]] = "tf.Softmax"(%[[pc_0]]) {T = "tfdtype$DT_FLOAT"} // WholeModel-DAG: %[[pc_1:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} -// WholeModel-DAG: "tf.DumpTensor"(%[[pc_0]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} -// WholeModel-DAG: "tf.DumpTensor"(%[[pc_1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// WholeModel-DAG: "tf.DumpTensor"(%[[pc_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// WholeModel-DAG: "tf.DumpTensor"(%[[pc_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} // WholeModel-DAG: return %[[pc_1]] // IntPerLayer-LABEL: func @matmul2_softmax @@ -212,13 +212,13 @@ module { // IntPerLayer-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893, -0.708605706 // IntPerLayer-DAG: %[[pc_0:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} // IntPerLayer-DAG: %[[pc_1:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0} -// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_0]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} -// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} // IntPerLayer-DAG: %[[sm_0:.*]] = "tf.Softmax"(%[[pc_0]]) {T = "tfdtype$DT_FLOAT"} // IntPerLayer-DAG: %[[pc_2:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} // IntPerLayer-DAG: %[[pc_3:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0} -// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_2]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} -// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_3]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_2]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_3]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} // IntPerLayer-DAG: return %[[pc_2]] // FloatPerLayer-LABEL: func @matmul2_softmax @@ -226,13 +226,13 @@ module { // FloatPerLayer-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893, -0.708605706 // FloatPerLayer-DAG: %[[pc_0:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} // FloatPerLayer-DAG: %[[pc_1:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0} -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_0]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} // FloatPerLayer-DAG: %[[sm_0:.*]] = "tf.Softmax"(%[[pc_1]]) {T = "tfdtype$DT_FLOAT"} // FloatPerLayer-DAG: %[[pc_2:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} // FloatPerLayer-DAG: %[[pc_3:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0} -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_2]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_3]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_2]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_3]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} // FloatPerLayer-DAG: return %[[pc_3]] } @@ -263,8 +263,8 @@ module { // WholeModel-DAG: %[[axis:.*]] = "tf.Const"() <{value = dense<-1> : tensor} // WholeModel-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} // WholeModel-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} -// WholeModel-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} -// WholeModel-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// WholeModel-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// WholeModel-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} // WholeModel-DAG: %[[c:.*]] = "tf.ConcatV2"(%[[m0]], %[[m1]], %[[axis]]) // WholeModel-DAG: return %[[c]] @@ -274,12 +274,12 @@ module { // IntPerLayer-DAG: %[[axis:.*]] = "tf.Const"() <{value = dense<-1> : tensor}> : () -> tensor // IntPerLayer-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // IntPerLayer-DAG: %[[m0_1:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () -// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () // IntPerLayer-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // IntPerLayer-DAG: %[[m1_0:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () -// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () // IntPerLayer-DAG: %4 = "tf.ConcatV2"(%[[m0]], %[[m1]], %[[axis]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor) -> tensor<2x4xf32> // IntPerLayer-DAG: return %4 : tensor<2x4xf32> @@ -289,12 +289,12 @@ module { // FloatPerLayer-DAG: %[[axis:.*]] = "tf.Const"() <{value = dense<-1> : tensor}> : () -> tensor // FloatPerLayer-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // FloatPerLayer-DAG: %[[m0_1:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () // FloatPerLayer-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0_1]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // FloatPerLayer-DAG: %[[m1_0:.*]] = "tf.PartitionedCall"(%[[m0_1]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () // FloatPerLayer-DAG: %4 = "tf.ConcatV2"(%1, %[[m1_0]], %[[axis]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor) -> tensor<2x4xf32> // FloatPerLayer-DAG: return %4 : tensor<2x4xf32> } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir index 357a6119fa8b0f..bdb9b320109597 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir @@ -35,9 +35,9 @@ module { // WholeModel-DAG: %[[b0:.*]] = stablehlo.constant dense<[-0.211145893 // WholeModel-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> // WholeModel-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// WholeModel-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () +// WholeModel-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () // WholeModel-DAG: %[[matmul1_q:.*]] = "tf.XlaCallModule"(%[[matmul0_q]], %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// WholeModel-DAG: "tf.DumpTensor"(%[[matmul1_q]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () +// WholeModel-DAG: "tf.DumpTensor"(%[[matmul1_q]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () // WholeModel-DAG: return %[[matmul1_q]] : tensor // WholeModel-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_2 // WholeModel-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1 @@ -46,13 +46,13 @@ module { // IntPerLayer-DAG: %[[b0:.*]] = stablehlo.constant dense<[-0.211145893 // IntPerLayer-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344 // IntPerLayer-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () // IntPerLayer-DAG: %[[matmul0_uq:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () // IntPerLayer-DAG: %[[matmul1_q:.*]] = "tf.XlaCallModule"(%[[matmul0_q]], %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_q]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () // IntPerLayer-DAG: %[[matmul1_uq:.*]] = "tf.XlaCallModule"(%[[matmul0_q]], %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_uq]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () // IntPerLayer-DAG: return %[[matmul1_q]] : tensor // IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_2 // IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1 @@ -63,13 +63,13 @@ module { // FloatPerLayer-DAG: %[[b0:.*]] = stablehlo.constant dense<[-0.211145893 // FloatPerLayer-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344 // FloatPerLayer-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () // FloatPerLayer-DAG: %[[matmul0_uq:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () // FloatPerLayer-DAG: %[[matmul1_q:.*]] = "tf.XlaCallModule"(%[[matmul0_uq]], %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_q]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () // FloatPerLayer-DAG: %[[matmul1_uq:.*]] = "tf.XlaCallModule"(%[[matmul0_uq]], %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_uq]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () // FloatPerLayer-DAG: return %[[matmul1_uq]] : tensor // FloatPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_2 // FloatPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1 @@ -96,7 +96,7 @@ module { // WholeModel-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344 // WholeModel-DAG: %[[c0:.*]] = stablehlo.constant dense<1.000000e+00 // WholeModel-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]]) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// WholeModel-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () +// WholeModel-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () // WholeModel-DAG: %[[concat:.*]] = stablehlo.concatenate %[[matmul0_q]], %[[c0]], dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32> // WholeModel-DAG: return %[[concat]] : tensor<2x3xf32> // WholeModel-DAG: func.func private @composite_dot_general_fn_1 @@ -105,9 +105,9 @@ module { // IntPerLayer-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344 // IntPerLayer-DAG: %[[c0:.*]] = stablehlo.constant dense<1.000000e+00 // IntPerLayer-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]]) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () // IntPerLayer-DAG: %[[matmul0_uq:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]]) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1_0, _original_entry_function = "composite_dot_general_fn_1_0", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () // IntPerLayer-DAG: %[[concat:.*]] = stablehlo.concatenate %[[matmul0_q]], %[[c0]], dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32> // IntPerLayer-DAG: return %[[concat]] : tensor<2x3xf32> // IntPerLayer-DAG: func.func private @composite_dot_general_fn_1 @@ -117,9 +117,9 @@ module { // FloatPerLayer-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344 // FloatPerLayer-DAG: %[[c0:.*]] = stablehlo.constant dense<1.000000e+00 // FloatPerLayer-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]]) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () // FloatPerLayer-DAG: %[[matmul0_uq:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]]) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1_0, _original_entry_function = "composite_dot_general_fn_1_0", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () // FloatPerLayer-DAG: %[[concat:.*]] = stablehlo.concatenate %[[matmul0_uq]], %[[c0]], dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32> // FloatPerLayer-DAG: return %[[concat]] : tensor<2x3xf32> // FloatPerLayer-DAG: func.func private @composite_dot_general_fn_1 diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index 81cdce725460fb..5d5342e8a264c4 100644 --- a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -1,6 +1,6 @@ -load("//tensorflow:strict.default.bzl", "py_strict_test") -load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("@local_tsl//tsl:tsl.default.bzl", "tsl_pybind_extension") +load("//tensorflow:pytype.default.bzl", "pytype_strict_library") +load("//tensorflow:strict.default.bzl", "py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index c1c8966849e4b9..26d5e4d52b41d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1,3 +1,5 @@ +# buildifier: disable=out-of-order-load + load("//tensorflow:strict.default.bzl", "py_strict_library") # copybara:uncomment_begin(google-only) @@ -354,12 +356,15 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Parser", "@llvm-project//mlir:SideEffectInterfaces", @@ -400,12 +405,15 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Parser", "@llvm-project//mlir:SideEffectInterfaces", @@ -447,12 +455,15 @@ cc_library( "//tensorflow/core/common_runtime:lower_function_call_inline_policy", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Parser", "@llvm-project//mlir:SideEffectInterfaces", @@ -521,7 +532,6 @@ cc_library( "ir/tf_saved_model.h", "ir/tf_structs.h", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.h", - "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], includes = ["include"], visibility = ["//visibility:public"], @@ -549,6 +559,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:CallOpInterfacesIncGen", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:ControlFlowInterfaces", @@ -558,6 +569,7 @@ cc_library( "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:MLProgramDialect", "@llvm-project//mlir:Parser", @@ -672,6 +684,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", ], ) @@ -1366,6 +1379,9 @@ cc_library( deps = [ ":tensorflow", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -1598,6 +1614,7 @@ cc_library( "tensorflow_side_effects", "tensorflow_types", "@llvm-project//mlir:IR", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc index d0a05e45617cf6..5ceda80490f688 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project #include "mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -55,9 +54,6 @@ ResourceConstructingOps ResourceConstructingOps::EntryState(Value value) { tf_saved_model::GlobalTensorOp>(func, barg.getArgNumber(), symbol_table); ResourceConstructingOps result(global_tensor); - if (func.getArgAttr(barg.getArgNumber(), kCompositeDevice)) { - result.is_on_composite_device = true; - } return result; } } else if (auto vh = dyn_cast(value.getDefiningOp())) { @@ -75,17 +71,47 @@ ResourceConstructingOps ResourceConstructingOps::join( ResourceConstructingOps ret; ret.ops.insert(lhs.ops.begin(), lhs.ops.end()); ret.ops.insert(rhs.ops.begin(), rhs.ops.end()); - ret.is_on_composite_device = - lhs.is_on_composite_device || rhs.is_on_composite_device; return ret; } void ResourceConstructingOps::print(raw_ostream &os) const { llvm::interleaveComma(ops, os << "["); + os << "]"; +} + +IsComposite::IsComposite(Operation *op) {} + +IsComposite IsComposite::EntryState(MLIRContext *context) { + return IsComposite(); +} + +IsComposite IsComposite::EntryState(Value value) { + IsComposite result; + if (auto barg = value.dyn_cast()) { + if (func::FuncOp func = + dyn_cast(barg.getOwner()->getParentOp())) { + if (func.getArgAttr(barg.getArgNumber(), kCompositeDevice)) { + result.is_on_composite_device = true; + } + return result; + } + } + return result; +} + +IsComposite IsComposite::join(const IsComposite &lhs, const IsComposite &rhs) { + IsComposite ret; + ret.is_on_composite_device = + lhs.is_on_composite_device || rhs.is_on_composite_device; + return ret; +} + +void IsComposite::print(raw_ostream &os) const { if (is_on_composite_device) { - os << " COMPOSITE"; + os << "COMPOSITE"; + } else { + os << "NOT_COMPOSITE"; } - os << "]"; } class ResourceDataflowAnalysis @@ -94,23 +120,32 @@ class ResourceDataflowAnalysis using TensorflowDataflowAnalysis< ResourceConstructingOps>::TensorflowDataflowAnalysis; void visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) override; + ArrayRef results) override { + if (ForwardThroughTFOperation(op, operands, results)) return; + setAllToEntryStates(results); + } ~ResourceDataflowAnalysis() override = default; }; -void ResourceDataflowAnalysis::visitOperation(Operation *op, - ArrayRef operands, - ArrayRef results) { - LLVM_DEBUG(llvm::dbgs() << "ResAn: Visiting operation: " << *op << "\n"); - - if (ForwardThroughTFOperation(op, operands, results)) return; - - setAllToEntryStates(results); -} +class IsCompositeDataflowAnalysis + : public TensorflowDataflowAnalysis { + public: + using TensorflowDataflowAnalysis::TensorflowDataflowAnalysis; + void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override { + if (ForwardThroughTFOperation(op, operands, results)) return; + setAllToEntryStates(results); + } + ~IsCompositeDataflowAnalysis() override = default; +}; void LoadResourceDataflowAnalysis(DataFlowSolver &solver) { solver.load(); } +void LoadIsCompositeDataflowAnalysis(DataFlowSolver &solver) { + solver.load(); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h index 9015b9dc739634..0cf3611af1d20c 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h @@ -46,8 +46,7 @@ struct ResourceConstructingOps { static ResourceConstructingOps EntryState(MLIRContext *context); static ResourceConstructingOps EntryState(Value value); bool operator==(const ResourceConstructingOps &rhs) const { - return ops == rhs.ops && - is_on_composite_device == rhs.is_on_composite_device; + return ops == rhs.ops; } static ResourceConstructingOps join(const ResourceConstructingOps &lhs, @@ -57,13 +56,27 @@ struct ResourceConstructingOps { // The operation(s) which created the resource value. // IR constructs (i.e., GlobalTensorOp) are not const-correct. mutable DenseSet ops; +}; + +struct IsComposite { + explicit IsComposite(Operation *op = nullptr); + static IsComposite EntryState(MLIRContext *context); + static IsComposite EntryState(Value value); + bool operator==(const IsComposite &rhs) const { + return is_on_composite_device == rhs.is_on_composite_device; + } + + static IsComposite join(const IsComposite &lhs, const IsComposite &rhs); + void print(raw_ostream &os) const; bool is_on_composite_device = false; }; typedef dataflow::Lattice ResourceDataflowState; +typedef dataflow::Lattice IsCompositeDataflowState; void LoadResourceDataflowAnalysis(DataFlowSolver &solver); +void LoadIsCompositeDataflowAnalysis(DataFlowSolver &solver); } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD index ddef04d4185e1d..ccf7b0b547ab90 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD @@ -73,6 +73,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/framework:resource_handle", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/BUILD index a6ecd2d9fbe91f..35336a005eba0c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/BUILD @@ -1,6 +1,6 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 55b68e5de2fb5f..db28242944434e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -2196,6 +2196,73 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- +// Tests inputs to TPUComputation that are tiled in multiple dimensions with +// replicate_on_last_tile_dim set. + +// The following OpSharding is used for TPU computation inputs in below test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// tile_assignment_devices: 2 +// tile_assignment_devices: 3 +// replicate_on_last_tile_dim: true +// Serialized string: +// "\08\03\1A\03\02\01\02\22\04\00\01\02\030\01" +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @multi_dimension_tiled_input_replicate_last_dim + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>) + func.func @multi_dimension_tiled_input_replicate_last_dim(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: %[[CONST_SPLIT_0_DIM:.*]] = "tf.Const"() + // CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]]) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#0, %[[COMPILE]]#1) + // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#0, %[[RI_1]], %[[COMPILE]]#2) + // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#1, %[[COMPILE]]#3) + // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]] + // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#1, %[[COMPILE]]#4) + // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\1A\03\02\01\02\22\04\00\01\02\030\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + func.return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + +// ----- + // Tests that tiled output with multiple dimension sharding works properly. // The following OpSharding is used for TPU computation outputs in below test: @@ -2278,6 +2345,73 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- +// Tests that tiled output with multiple dimension sharding works properly with +// replicate_on_last_tile_dim set. + +// The following OpSharding is used for TPU computation outputs in below test: +// output 0 +// Proto debug string: +// type: OTHER +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// tile_assignment_devices: 2 +// tile_assignment_devices: 3 +// replicate_on_last_tile_dim: true +// Serialized string: +// "\08\03\1A\03\02\01\02\22\04\00\01\02\030\01" +// +// output 1 +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\00" + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @multi_dimension_tiled_output_replicate_last_dim + func.func @multi_dimension_tiled_output_replicate_last_dim(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute" + // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"( + // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]] + // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"( + // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] + // CHECK: %[[CONST_CONCAT_DIM:.*]] = "tf.Const"() + // CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#3 + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\03\02\01\02\22\04\00\01\02\030\01", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + func.return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + +// ----- + // Tests inputs device assignment order is well preserved for tiled input sharding. // The following OpSharding is used for TPU computation inputs in below test: diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 2bbd90a3aeeebc..3d1cf1bd58fa38 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -60,6 +60,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -337,6 +338,7 @@ cc_library( "@llvm-project//mlir:MLProgramDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -458,9 +460,6 @@ cc_library( "device_index_selector.cc", "drop_while_shape_invariant.cc", "einsum.cc", - "embedding_pipelining.cc", - "embedding_program_key.cc", - "embedding_sequencing.cc", "executor_island_coarsening.cc", "executor_tpuv1_inline_tpu_island.cc", "executor_tpuv1_island_coarsening.cc", @@ -652,11 +651,13 @@ cc_library( "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index f265ac68fa5f27..4de43317677f63 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -145,7 +145,7 @@ bool OnlyOperatesOnCompositeDevices( continue; } auto lattice = - solver.lookupState(arg.get())->getValue(); + solver.lookupState(arg.get())->getValue(); bool is_read = read_array.contains(arg.getOperandNumber()); bool is_update = update_array.contains(arg.getOperandNumber()); // We want the resource operands that are on composite devices to be the @@ -214,7 +214,7 @@ void CollectChainResources( // device-specific (see below). bool resource_is_on_composite_device = false; for (Value value : alias_analysis.GetValuesForResourceId(resource_id)) { - auto lattice = solver.lookupState(value); + auto lattice = solver.lookupState(value); if (lattice) { resource_is_on_composite_device |= lattice->getValue().is_on_composite_device; @@ -604,7 +604,7 @@ void ConvertControlToDataOutputsPass::runOnOperation() { DataFlowSolver solver; solver.load(); solver.load(); - TF::LoadResourceDataflowAnalysis(solver); + TF::LoadIsCompositeDataflowAnalysis(solver); if (failed(solver.initializeAndRun(module))) return signalPassFailure(); // This pass assumes that all functions are suitable for export i.e., each diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index 5d500453d17fe0..f8e75d9032f3e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -26,11 +26,12 @@ cc_library( deps = [ ":runtime_passes", "//tensorflow/compiler/jit:flags_headers", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes", "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/platform:error_payloads", @@ -62,6 +63,7 @@ tf_cc_test( ":lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir:register_common_dialects", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc index 6f46766a3250fa..a239c7304a0ae0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc @@ -28,6 +28,8 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -121,6 +123,7 @@ void CreateNonTPULowerClusterToRuntimeOpsPassPipeline( // TODO(b/306728216): Move this out of the Bridge component and into a Host // runtime component. tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, + std::string bridge_type, tsl::DeviceType device_type, absl::Status status) { if (status.ok()) { @@ -129,11 +132,12 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, VLOG(2) << error_prefix << " " << status; tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - device_type.type_string(), /*bridge_version=*/"v2", + bridge_type, + /*bridge_version=*/mlir::TF::kMlirPh1BridgeCounterV2, + device_type.type_string(), /*fallback_enabled=*/false, /*result=*/"failure"); - constexpr char kBridgeComponent[] = "TFXLABridge"; std::string bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_TPU_BRIDGE"; tsl::OkOrSetErrorCounterPayload( @@ -144,7 +148,7 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE"; } - tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent, + tsl::error_logging::Log(mlir::TF::kBridgeComponent, bridge_subcomponent, status.ToString()) .IgnoreError(); @@ -194,10 +198,13 @@ absl::Status RunLowerClusterToRuntimeOpsPassPipeline( module, llvm::StringRef(), &runtime_lowering); } + std::string bridge_type = xla_device_type == DeviceType(DEVICE_TPU_XLA_JIT) + ? mlir::TF::kMlirPh1BridgeCounterReplicated + : mlir::TF::kMlirPh1BridgeCounterNonReplicated; auto result_status = diag_handler.ConsumeStatus(); TF_RETURN_IF_ERROR( RecordIfErrorStatus(/*error_prefix=*/"lower_cluster_to_runtime", - xla_device_type, result_status)); + bridge_type, xla_device_type, result_status)); return absl::OkStatus(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc index 3e3e8db504f1da..1f0cf146203de2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/env.h" @@ -167,9 +168,11 @@ TEST_F(LowerClusterToRuntimeOpsTest, ErrorsWithBadCluster) { *mlir_module_, DeviceType(DEVICE_TPU_XLA_JIT)) .ok()); - EXPECT_EQ(compilation_status.Delta("XLA_TPU_JIT", "v2", "fallback_disabled", - "failure"), - 1); + EXPECT_EQ( + compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, "XLA_TPU_JIT", + "fallback_disabled", "failure"), + 1); } TEST_F(LowerClusterToRuntimeOpsTest, DumpsPipelinePasses) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 9c475f1f9f5281..da89e77cb0862c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -446,13 +446,6 @@ std::unique_ptr> CreateReplicateToIslandPass( std::unique_ptr> CreateReplicaIDToDeviceOrdinalPass(); -// Creates a pass that adds pipelining to a graph that contains device -// accelerated embeddings. The EmbeddingSequencingPass is a temporary fallback -// while developing full pipelining capabilities. -std::unique_ptr> CreateEmbeddingSequencingPass(); -std::unique_ptr> CreateEmbeddingPipeliningPass(); -std::unique_ptr> CreateEmbeddingProgramKeyPass(); - // Creates a pass that creates `tf_executor.island` from a single // `tf_device.parallel_execute` island. std::unique_ptr> CreateParallelExecuteToIslandsPass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD new file mode 100644 index 00000000000000..bff95d357c885f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD @@ -0,0 +1,123 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir:__pkg__", + "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", + "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:__pkg__", + "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", + "//tensorflow/compiler/mlir/tf2xla/internal:__pkg__", + ], + licenses = ["notice"], +) + +gentbl_cc_library( + name = "sparsecore_passes_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=SparseCore", + ], + "sparsecore_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "sparsecore_passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +cc_library( + name = "sparsecore_passes", + hdrs = [ + "sparsecore_passes.h", + ], + textual_hdrs = [ + "sparsecore_passes.h.inc", + ], + deps = [ + ":embedding_pipelining", + ":embedding_program_key", + ":embedding_sequencing", + ":sparsecore_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "embedding_pipelining", + srcs = ["embedding_pipelining.cc"], + hdrs = [ + "sparsecore_passes.h", + ], + deps = [ + ":sparsecore_passes_inc_gen", + "//tensorflow/compiler/jit:flags_headers", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "embedding_sequencing", + srcs = ["embedding_sequencing.cc"], + hdrs = [ + "sparsecore_passes.h", + ], + deps = [ + ":sparsecore_passes_inc_gen", + "//tensorflow/compiler/jit:flags_headers", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "embedding_program_key", + srcs = ["embedding_program_key.cc"], + hdrs = [ + "sparsecore_passes.h", + ], + deps = [ + ":sparsecore_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_xla//xla/mlir_hlo", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc similarity index 99% rename from tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc index ee334b3f032155..0c450126e4e090 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc @@ -157,7 +157,7 @@ return selected_results #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #define GEN_PASS_DEF_EMBEDDINGPIPELININGPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc" static constexpr char kEmbeddingPipelining[] = "_embedding_pipelining"; static constexpr char kEmbeddingPipeliningInlineAttr[] = @@ -1289,7 +1289,7 @@ LogicalResult StartStep0(OpBuilder& builder, Location& loc, func::FuncOp orig_parent_func = callers.backward->getParentOfType(); - std::vector operands = loop_operands_nm0; + const std::vector& operands = loop_operands_nm0; // Input types will be the same as the original loop body. std::vector input_types = GetValueTypes(operands); @@ -1373,7 +1373,7 @@ LogicalResult StartStep1(OpBuilder& builder, Location& loc, func::FuncOp orig_parent_func = callers.backward->getParentOfType(); - std::vector operands = loop_operands_1; + const std::vector& operands = loop_operands_1; // Input types will be the same as the original loop body. std::vector input_types = GetValueTypes(operands); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_program_key.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc similarity index 99% rename from tensorflow/compiler/mlir/tensorflow/transforms/embedding_program_key.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc index a5575ef156ddb9..3e41762feb16c2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_program_key.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { @@ -42,7 +41,7 @@ constexpr char kMiniBatchSplitsAttr[] = "mini_batch_splits"; constexpr char kMiniBatchCsrAttr[] = "mini_batch_in_csr"; #define GEN_PASS_DEF_EMBEDDINGPROGRAMKEYPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc" struct EmbeddingProgramKeyPass : public impl::EmbeddingProgramKeyPassBase { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc similarity index 98% rename from tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc index a77dd6f498a144..7ed29a3ed58cc3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc @@ -32,6 +32,8 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Casting.h" @@ -40,6 +42,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project @@ -47,17 +50,20 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #define GEN_PASS_DEF_EMBEDDINGSEQUENCINGPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc" static constexpr char kEmbeddingPipelining[] = "_embedding_pipelining"; static constexpr char kEmbeddingForward[] = "forward"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h new file mode 100644 index 00000000000000..8944745dd3fff9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TFDevice { + +// For architectures that support accelerated embedding lookups, this pass will +// rewrite the graph to use pipelining for better device utilization. +std::unique_ptr> CreateEmbeddingSequencingPass(); + +// This is a strictly sequential and formally correct fallback option for the +// embedding pipelining pass intended for debugging during pipelining +// development. +std::unique_ptr> CreateEmbeddingPipeliningPass(); + +// Passes in the program key to embedding ops, by moving the embedding ops +// after the _TPUCompileMlir op. +std::unique_ptr> CreateEmbeddingProgramKeyPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_EMBEDDINGSEQUENCINGPASS +#define GEN_PASS_DECL_EMBEDDINGPIPELININGPASS +#define GEN_PASS_DECL_EMBEDDINGPROGRAMKEYPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc" + +} // namespace TFDevice +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.td new file mode 100644 index 00000000000000..a9c5981393df6c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.td @@ -0,0 +1,83 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def EmbeddingPipeliningPass : Pass<"tf-embedding-pipelining", "mlir::ModuleOp"> { + let summary = "Rewrite graph for embedding pipelining"; + let constructor = "TFDevice::CreateEmbeddingPipeliningPass()"; + let description = [{ + For architectures that support accelerated embedding lookups, this pass will + rewrite the graph to use pipelining for better device utilization. + }]; +} + +def EmbeddingSequencingPass : Pass<"tf-embedding-sequencing", "mlir::ModuleOp"> { + let summary = "Rewrite graph for sequential execution of embeddings"; + let constructor = "TFDevice::CreateEmbeddingSequencingPass()"; + let description = [{ + This is a strictly sequential and formally correct fallback option for the + embedding pipelining pass intended for debugging during pipelining + development. + }]; +} + +def EmbeddingProgramKeyPass : Pass<"tf-embedding-program-key", "mlir::func::FuncOp"> { + let summary = "Sets the program key for embedding ops."; + let constructor = "TFDevice::CreateEmbeddingProgramKeyPass()"; + let description = [{ + Passes in the program key to embedding ops. Will move the embedding ops + after a _TPUCompileMlir op if there is no predecessor _TPUCompileMlir op. + Both the embedding op and compile op are assumed to be wrapped in separate + tf_device.launch() ops. This is because the embedding op is head outside + compiled and the compile op is wrapped in launch to execute on host + during TPURewritePass. + + For example, the tf.OpA with the `mini_batch_splits` attribute will be + moved after _TPUCompileMlir and the first input will use the + _TPUCompileMlir program output: + + ```mlir + "tf_device.launch"() ({ + %cst_0 = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + "tf.OpA"(%cst_0) { mini_batch_splits = ""} : (tensor<1x!tf_type.string>) -> () + tf_device.return + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () + %0:2 = "tf_device.launch"() ({ + %compilation_status, %program = "tf._TPUCompileMlir"() { metadata = "...", mlir_module = "..." } : () -> (tensor, tensor<3x!tf_type.string>) + tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf_type.string>) + ``` + + becomes: + + ```mlir + %0:2 = "tf_device.launch"() ({ + %compilation_status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor, tensor<3x!tf_type.string>) + tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf_type.string>) + "tf_device.launch"() ({ + %cst = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + "tf.OpA"(%0#1) {mini_batch_splits = ""} : (tensor<3x!tf_type.string>) -> () + tf_device.return + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () + ``` + }]; + + let dependentDialects = [ + "mhlo::MhloDialect", + "tf_device::TensorFlowDeviceDialect" + ]; +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index b00e70eb73c4cc..6b53cae7099688 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -329,73 +329,6 @@ def ReplicaIDToDeviceOrdinalPass : Pass<"tf-replica-id-to-device-ordinal", "mlir }]; } -def EmbeddingPipeliningPass : Pass<"tf-embedding-pipelining", "mlir::ModuleOp"> { - let summary = "Rewrite graph for embedding pipelining"; - let constructor = "TFDevice::CreateEmbeddingPipeliningPass()"; - let description = [{ - For architectures that support accelerated embedding lookups, this pass will - rewrite the graph to use pipelining for better device utilization. - }]; -} - -def EmbeddingProgramKeyPass : Pass<"tf-embedding-program-key", "mlir::func::FuncOp"> { - let summary = "Sets the program key for embedding ops."; - let constructor = "TFDevice::CreateEmbeddingProgramKeyPass()"; - let description = [{ - Passes in the program key to embedding ops. Will move the embedding ops - after a _TPUCompileMlir op if there is no predecessor _TPUCompileMlir op. - Both the embedding op and compile op are assumed to be wrapped in separate - tf_device.launch() ops. This is because the embedding op is head outside - compiled and the compile op is wrapped in launch to execute on host - during TPURewritePass. - - For example, the tf.OpA with the `mini_batch_splits` attribute will be - moved after _TPUCompileMlir and the first input will use the - _TPUCompileMlir program output: - - ```mlir - "tf_device.launch"() ({ - %cst_0 = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> - "tf.OpA"(%cst_0) { mini_batch_splits = ""} : (tensor<1x!tf_type.string>) -> () - tf_device.return - }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () - %0:2 = "tf_device.launch"() ({ - %compilation_status, %program = "tf._TPUCompileMlir"() { metadata = "...", mlir_module = "..." } : () -> (tensor, tensor<3x!tf_type.string>) - tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> - }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf_type.string>) - ``` - - becomes: - - ```mlir - %0:2 = "tf_device.launch"() ({ - %compilation_status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor, tensor<3x!tf_type.string>) - tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> - }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf_type.string>) - "tf_device.launch"() ({ - %cst = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> - "tf.OpA"(%0#1) {mini_batch_splits = ""} : (tensor<3x!tf_type.string>) -> () - tf_device.return - }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () - ``` - }]; - - let dependentDialects = [ - "mhlo::MhloDialect", - "tf_device::TensorFlowDeviceDialect" - ]; -} - -def EmbeddingSequencingPass : Pass<"tf-embedding-sequencing", "mlir::ModuleOp"> { - let summary = "Rewrite graph for sequential execution of embeddings"; - let constructor = "TFDevice::CreateEmbeddingSequencingPass()"; - let description = [{ - This is a strictly sequential and formally correct fallback option for the - embedding pipelining pass intended for debugging during pipelining - development. - }]; -} - def ConvertReadonlyReferenceVariablesToResourceVariablesPass : Pass<"tf-readonly-references-to-resources", "mlir::func::FuncOp"> { let summary = "Convert readonly reference variables to resource variables."; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD index 46af8590c8108e..59d7cfd7081106 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -122,6 +122,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:protobuf", ], alwayslink = 1, ) @@ -287,6 +288,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:protobuf", "@local_xla//xla/client:client_library", "@local_xla//xla/client:compile_only_client", "@local_xla//xla/service/cpu:cpu_compiler", diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 6042ae37ee8fa2..523048cd7cd582 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -421,7 +421,7 @@ Status Exporter::AddInstructionNode(Operation* inst) { inst, name, /*ignore_unregistered_attrs=*/false)); UseOriginalFunctionNames(*node_def); - TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def)); + TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(std::move(*node_def))); DCHECK(node != nullptr); nodes_[inst] = node; return OkStatus(); @@ -436,7 +436,7 @@ bool IsEntryFunctionArg(BlockArgument arg) { Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index, llvm::StringRef name) { TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name)); - TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def)); + TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(std::move(*node_def))); args_[arg] = node; return OkStatus(); } @@ -455,7 +455,7 @@ Status Exporter::AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch, GetReturnNode(function, operand_and_idx.value(), operand_and_idx.index(), names.empty() ? "" : names[operand_and_idx.index()])); - TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def)); + TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(std::move(*node_def))); return_nodes.push_back(node); } return OkStatus(); @@ -687,15 +687,6 @@ Status Exporter::ConvertLibFunction( TF_RETURN_IF_ERROR( GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def)); - // The node defs in FunctionDef might contain debug info which was added - // by the GraphToFunctionDef method. We should remove it if we don't want - // to export them to avoid failing the roundtrip test. - if (!configs.export_debug_info) { - for (auto& node_def : *func_def.mutable_node_def()) { - node_def.clear_experimental_debug_info(); - } - } - // Checks for gradient attribute. If present converts the gradient function // and populates the GradientDef. auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName(); @@ -831,17 +822,6 @@ StatusOr> ConvertMlirToGraphdef( auto graphdef = std::make_unique(); graph->ToGraphDef(graphdef.get()); - if (!configs.export_library) graphdef->clear_library(); - if (!configs.export_shapes) { - for (auto& node_def : *graphdef->mutable_node()) { - node_def.mutable_attr()->erase("shape"); - } - } - if (!configs.export_debug_info) { - for (auto& node_def : *graphdef->mutable_node()) { - node_def.clear_experimental_debug_info(); - } - } return graphdef; } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index 00fd5b7de6aa4d..fca039c2601636 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -102,12 +102,6 @@ struct GraphImportConfig { }; struct GraphExportConfig { - // Whether to export shape attribute for the NodeDefs in the GraphDef. - bool export_shapes = true; - // Whether to export library field in the GraphDef. - bool export_library = true; - // Whether to export debug original node name in the GraphDef. - bool export_debug_info = true; // Whether to export the entry function to function library instead of the // graph. bool export_entry_func_to_flib = false; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index 57b0d0e2ff2389..eb9bf3db34106d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tsl/platform/protobuf.h" namespace mlir { using tsl::Status; @@ -152,7 +153,7 @@ static LogicalResult MlirToGraphTranslateFunction(ModuleOp module, // Print the graph to the output after going through GraphDef conversion. // The DumpGraphToFile would do this anyway so just skip straight to it. graph->ToGraphDef(graphdef.get()); - output << graphdef->DebugString(); + output << tsl::LegacyUnredactedDebugString(*graphdef); return success(); } @@ -167,7 +168,6 @@ static LogicalResult MlirToGraphdefTranslateFunction( ModuleOp module, llvm::raw_ostream& output) { if (!module) return failure(); - // TODO(fengliuai): Add exporter flags. tensorflow::GraphExportConfig confs; confs.export_entry_func_to_flib = export_entry_func_to_flib; confs.export_original_tf_func_name = export_original_tf_func_name; @@ -179,7 +179,7 @@ static LogicalResult MlirToGraphdefTranslateFunction( return mlir::failure(); } - output << graphdef_or.value()->DebugString(); + output << tsl::LegacyUnredactedDebugString(*graphdef_or.value()); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc index 08cbb51a576760..856db032e501ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tsl/platform/protobuf.h" namespace mlir { static mlir::Operation* ExtractOnlyOp(mlir::ModuleOp module) { @@ -61,7 +62,7 @@ static LogicalResult MlirToTfNodeDef(ModuleOp module, return failure(); } - output << node_def_or.value()->DebugString(); + output << tsl::LegacyUnredactedDebugString(*node_def_or.value()); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h index b50135c9bdfac3..5a99806d4295f3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -121,6 +121,18 @@ inline constexpr llvm::StringRef kDynamicArgIndexAttr = "_dynamic_arg_index"; inline constexpr llvm::StringRef kParallelExecAnnotation = "_parallel_execution_ids"; +// Logging + +// Name of component for error logging. This name is fixed and required to +// enable logging. +inline const char kBridgeComponent[] = "TFXLABridge"; +inline const char kMlirPh1BridgeCounterReplicated[] = "replicated"; +inline const char kMlirPh1BridgeCounterNonReplicated[] = "nonreplicated"; +inline const char kMlirPh1BridgeCounterV1[] = "v1"; +inline const char kMlirPh1BridgeCounterV2[] = "v2"; +inline const char kMlirPh1BridgeCounterTpu[] = "tpu"; +inline const char kMlirPh1BridgeCounterNonTpu[] = "cpu/gpu"; + // Copies attributes that satisfy the given predicate from `from` to `to`. template void CopyAttributes(Operation *from, Operation *to, Predicate P) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index bb474b1413f7ac..2efd63b29b04ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -125,7 +125,7 @@ TEST(DumpCrashReproducerTest, RoundtripDumpAndReadValid) { EXPECT_TRUE(mlir::MlirOptMain(output_stream->os(), std::move(input_file), registry, mlir::MlirOptMainConfig{} - .splitInputFile(false) + .splitInputFile("") .verifyDiagnostics(false) .verifyPasses(false) .allowUnregisteredDialects(false) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index f9dc740cee1aae..f01a3f0e09d19b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -61,6 +61,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { @@ -395,12 +396,12 @@ Status ConvertAttributes( if (auto symbol_ref = attr.dyn_cast()) { TF_RETURN_IF_ERROR( ConvertAttribute(symbol_ref.cast(), &value)); - func_call_attrs[string(name)] = value; + func_call_attrs[string(name)] = std::move(value); continue; } if (auto func_attr = attr.dyn_cast()) { TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value)); - func_call_attrs[string(name)] = value; + func_call_attrs[string(name)] = std::move(value); continue; } if (attr.isa()) { @@ -434,13 +435,14 @@ Status ConvertAttributes( TF_RET_CHECK(name_tokens.size() <= 2); auto it = func_call_attrs.find(name_tokens[0]); if (it == func_call_attrs.end()) { - (*values)[string(name)] = value; + (*values)[string(name)] = std::move(value); } else { - (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value; + (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = + std::move(value); } } - for (const auto& it : func_call_attrs) { - (*values)[it.first] = it.second; + for (auto& it : func_call_attrs) { + (*values)[it.first] = std::move(it.second); } return OkStatus(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index 58adaa41349b14..ea76adb284b7e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -15,10 +15,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" +#include +#include #include #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -152,16 +157,21 @@ mlir::LogicalResult HandleTileShardedInputs( // are created such that input data is sharded in row major order. // Split nodes at ith depth from the original input node represent nodes // that split the input data at i-th dimension. - const auto& dimension_splits = input_sharding.tile_assignment_dimensions(); - for (const auto& num_splits_and_index : llvm::enumerate(dimension_splits)) { - const int num_splits = num_splits_and_index.value(); - const int dimension_index = num_splits_and_index.index(); - if (num_splits == 1) continue; + auto dimension_to_splits_map = + GetDimensionIndicesAndNumSplitsFromSharding(input_sharding); + if (!dimension_to_splits_map.ok()) { + LOG(ERROR) << dimension_to_splits_map.status(); + return mlir::failure(); + } + + for (const auto& dimension_and_num_splits : *dimension_to_splits_map) { + const int dimension = dimension_and_num_splits.first; + const int num_splits = dimension_and_num_splits.second; // Creates root split op. if (split_ops_for_tiled_input.empty()) { mlir::TF::SplitOp root_split_op; - auto result = CreateSplitOp(num_splits, dimension_index, location, + auto result = CreateSplitOp(num_splits, dimension, location, original_source, builder, &root_split_op); if (mlir::failed(result)) return mlir::failure(); @@ -176,7 +186,7 @@ mlir::LogicalResult HandleTileShardedInputs( for (auto parent_split_output_value : split_op.getResults()) { mlir::TF::SplitOp child_split_op; auto result = - CreateSplitOp(num_splits, dimension_index, location, + CreateSplitOp(num_splits, dimension, location, parent_split_output_value, builder, &child_split_op); if (mlir::failed(result)) return mlir::failure(); @@ -188,12 +198,21 @@ mlir::LogicalResult HandleTileShardedInputs( } // `split_ops_for_tiled_input` now includes final split nodes - // from which sharded data will be fed into TPUExcute ops -- sorted by + // from which sharded data will be fed into TPUExecute ops -- sorted by // row major order. + tiled_inputs->clear(); tiled_inputs->reserve(input_sharding.tile_assignment_devices_size()); - for (auto split_op : split_ops_for_tiled_input) - tiled_inputs->append(split_op.getResults().begin(), - split_op.getResults().end()); + for (auto split_op : split_ops_for_tiled_input) { + for (auto split_op_output : split_op.getResults()) { + int64_t repeat_count = + input_sharding.replicate_on_last_tile_dim() + ? *input_sharding.tile_assignment_dimensions().rbegin() + : 1; + for (int64_t i = 0; i < repeat_count; ++i) { + tiled_inputs->push_back(split_op_output); + } + } + } return mlir::success(); } @@ -205,6 +224,29 @@ bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) { } // namespace +absl::StatusOr> GetDimensionIndicesAndNumSplitsFromSharding( + const xla::OpSharding& sharding) { + int64_t tensor_tile_rank = sharding.tile_assignment_dimensions_size(); + if (sharding.replicate_on_last_tile_dim()) { + tensor_tile_rank--; + } + + std::map dimension_to_splits_map; + for (int dim_index = 0; dim_index < tensor_tile_rank; ++dim_index) { + if (sharding.tile_assignment_dimensions(dim_index) > 1) { + dimension_to_splits_map.emplace( + dim_index, sharding.tile_assignment_dimensions(dim_index)); + } + } + + if (dimension_to_splits_map.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Arg has unnecessary tiled sharding: ", sharding.DebugString())); + } + + return dimension_to_splits_map; +} + int GetDimsFromXLAShardingTiled(const xla::OpSharding& xla_sharding) { return xla_sharding.tile_assignment_dimensions_size() - (xla_sharding.replicate_on_last_tile_dim() ? 1 : 0) - @@ -478,15 +520,25 @@ mlir::LogicalResult GetTileShardedOutputsToMerge( const xla::OpSharding& sharding = output_sharding_config[cluster_func_output_index]; outputs_to_merge->reserve(sharding.tile_assignment_devices_size()); - for (const auto logical_device_id : sharding.tile_assignment_devices()) { + for (const auto& core_id_and_index : + llvm::enumerate(sharding.tile_assignment_devices())) { + auto core_id = core_id_and_index.value(); + auto tile_index = core_id_and_index.index(); + + int last_tile_dim_size = *sharding.tile_assignment_dimensions().rbegin(); + if (sharding.replicate_on_last_tile_dim() && + tile_index % last_tile_dim_size != 0) { + continue; + } + int region_output_index; - auto status = LookupClusterToCoreIndex( - location, cluster_to_core_index, logical_device_id, - cluster_func_output_index, ®ion_output_index); + auto status = LookupClusterToCoreIndex(location, cluster_to_core_index, + core_id, cluster_func_output_index, + ®ion_output_index); if (failed(status)) return mlir::failure(); const auto output_from_logical_device = - new_parallel_execute.GetRegionOutputs( - cluster_idx + logical_device_id)[region_output_index]; + new_parallel_execute.GetRegionOutputs(cluster_idx + + core_id)[region_output_index]; outputs_to_merge->emplace_back(output_from_logical_device); } @@ -518,12 +570,18 @@ mlir::LogicalResult HandleTileShardedOutputs( // devices to a single replica output. const xla::OpSharding& sharding = output_sharding_config[cluster_func_output_index]; - int concat_dimension = sharding.tile_assignment_dimensions_size() - 1; - for (auto num_splits : llvm::reverse(sharding.tile_assignment_dimensions())) { - if (num_splits == 1) { - --concat_dimension; - continue; - } + + auto dimension_to_splits_map = + GetDimensionIndicesAndNumSplitsFromSharding(sharding); + if (!dimension_to_splits_map.ok()) { + LOG(ERROR) << dimension_to_splits_map.status(); + return mlir::failure(); + } + + for (auto it = dimension_to_splits_map->rbegin(); + it != dimension_to_splits_map->rend(); ++it) { + int concat_dimension = it->first; + int num_splits = it->second; llvm::SmallVector new_outputs; new_outputs.reserve(num_splits); @@ -539,7 +597,6 @@ mlir::LogicalResult HandleTileShardedOutputs( } std::swap(new_outputs, outputs_to_merge); - --concat_dimension; } assert(outputs_to_merge.size() == 1); @@ -552,33 +609,35 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( const mlir::TensorType cluster_func_output_type, const xla::OpSharding& output_sharding, mlir::Type* tiled_logical_computation_type) { - auto new_output_shape = - llvm::to_vector<4>(cluster_func_output_type.getShape()); - for (const auto& dimension_and_output_splits : - llvm::enumerate(output_sharding.tile_assignment_dimensions())) { - const auto dimension_index = dimension_and_output_splits.index(); - const auto output_splits = dimension_and_output_splits.value(); - const auto output_shape = cluster_func_output_type.getShape(); - - if (output_shape[dimension_index] == mlir::ShapedType::kDynamic) { + const auto output_shape = cluster_func_output_type.getShape(); + auto new_output_shape = llvm::to_vector<4>(output_shape); + auto dimension_to_splits_map = + GetDimensionIndicesAndNumSplitsFromSharding(output_sharding); + if (!dimension_to_splits_map.ok()) { + LOG(ERROR) << dimension_to_splits_map.status(); + return mlir::failure(); + } + + for (const auto& dimension_and_output_splits : *dimension_to_splits_map) { + const auto dimension = dimension_and_output_splits.first; + const auto output_splits = dimension_and_output_splits.second; + + if (output_shape[dimension] == mlir::ShapedType::kDynamic) { *tiled_logical_computation_type = cluster_func_output_type; break; } - auto output_shape_at_dim = - cluster_func_output_type.getShape()[dimension_index]; - if (output_shape_at_dim % output_splits != 0) { + if (output_shape[dimension] % output_splits != 0) { mlir::emitError( location, llvm::formatv("incorrect output sharding received. " "{0}-th dimension of the output must be " "evenly divisible by {1}, got dimension " "shape {2}", - dimension_index, output_splits, output_shape_at_dim)); + dimension, output_splits, output_shape[dimension])); } - new_output_shape[dimension_index] = - output_shape[dimension_index] / output_splits; + new_output_shape[dimension] = output_shape[dimension] / output_splits; } *tiled_logical_computation_type = mlir::RankedTensorType::get( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index 6295be3776416e..ab22eb978214ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ +#include #include +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -122,6 +124,9 @@ bool IsSplitSharding(const xla::OpSharding& sharding); // REPLICATED type and replicated OTHER type. bool IsReplicatedSharding(const xla::OpSharding& sharding); +// Returns a map of dimension indices and number of splits for tiled sharding. +absl::StatusOr> GetDimensionIndicesAndNumSplitsFromSharding( + const xla::OpSharding& sharding); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 38094bf7067d1b..53a65bd3ae3662 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -77,16 +77,28 @@ tf_cc_test( srcs = ["compile_mlir_util_test.cc"], deps = [ ":compile_mlir_util_no_tf_dialect_passes", + "//tensorflow/compiler/jit:xla_compile_util", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/mlir/utils:array_container_utils", + "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/monitoring:cell_reader", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/client:xla_builder", ], ) @@ -182,6 +194,7 @@ cc_library( ], deps = [ ":tf_dialect_to_executor", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -220,6 +233,7 @@ tf_cc_test( deps = [ ":cluster_tf", "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/core/lib/monitoring:cell_reader", "//tensorflow/core/platform:resource_loader", @@ -229,7 +243,6 @@ tf_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/lib/monitoring:test_utils", "@local_tsl//tsl/platform:status", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc index 09209d8673524c..38c11ec857f072 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -60,10 +61,6 @@ using mlir::func::FuncOp; namespace { -// Name of component for error logging. This name is fixed and required to -// enable logging. -constexpr char kBridgeComponent[] = "TFXLABridge"; - void CreateReplicatedBridgePipelineV1(OpPassManager &pm) { pm.addPass(mlir::tf2xla::internal::CreateInferenceMetricsPass()); @@ -152,10 +149,12 @@ tensorflow::Status RecordStatusIfError(const std::string error_prefix, } tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - /*device_type=*/"tpu", /*bridge_version=*/"v1", + /*bridge_type=*/mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version=*/mlir::TF::kMlirPh1BridgeCounterV1, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, /*fallback_enabled=*/is_in_fallback_enabled_mode, /*result=*/"failure"); - tsl::error_logging::Log(kBridgeComponent, + tsl::error_logging::Log(mlir::TF::kBridgeComponent, "TFXLA_PHASE_ONE_MLIR_TPU_V1_COMPAT_BRIDGE", status.ToString()) .IgnoreError(); @@ -221,7 +220,9 @@ tensorflow::Status RunSessionTf2xlaClusteringBridge( RunClusteringPipelineOnSubmodule(module, is_in_fallback_enabled_mode)); tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - /*device_type=*/"tpu", /*bridge_version=*/"v1", + /*bridge_type=*/mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version=*/mlir::TF::kMlirPh1BridgeCounterV1, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, /*n_fallback_enabled*/ is_in_fallback_enabled_mode, /*result=*/"success"); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc index 44eafb25f579c8..e674989d2174ba 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/resource_loader.h" #include "tsl/lib/core/status_test_util.h" @@ -84,8 +85,11 @@ TEST_F(SessionClusterTensorflowDialectTest, ClustersTf) { TF_EXPECT_OK( RunSessionTf2xlaClusteringBridge(*mlir_module_, /*is_in_fallback_enabled_mode=*/false)); - EXPECT_EQ( - compilation_status.Delta("tpu", "v1", "fallback_disabled", "success"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV1, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_disabled", "success"), + 1); } TEST_F(SessionClusterTensorflowDialectTest, FailsWithMultipleSubmodules) { @@ -98,8 +102,11 @@ TEST_F(SessionClusterTensorflowDialectTest, FailsWithMultipleSubmodules) { /*is_in_fallback_enabled_mode=*/false) .ok()); - EXPECT_EQ( - compilation_status.Delta("tpu", "v1", "fallback_disabled", "failure"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV1, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_disabled", "failure"), + 1); } } // namespace diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc index 20fff0cc549d0f..59fb22e87eab58 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc @@ -989,7 +989,8 @@ Status CompileGraphToXlaHlo( } absl::StatusOr> GraphToModule( - const Graph& graph, llvm::ArrayRef control_rets, + bool unconditionally_use_set_output_shapes, const Graph& graph, + llvm::ArrayRef control_rets, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, mlir::MLIRContext* context) { mlir::DialectRegistry registry; @@ -1004,20 +1005,27 @@ absl::StatusOr> GraphToModule( // the shape inference pass is run early in the pass pipeline, shape inference // during import is not necessary. config.enable_shape_inference = false; + // Some graphs may require _output_shapes (an unregistered attribute) + // to override shapes. It is unfortunately not always set correctly so only + // do it optionally. + config.unconditionally_use_set_output_shapes = + unconditionally_use_set_output_shapes; return ConvertGraphToMlir(graph, debug_info, flib_def, config, context); } Status BuildHloFromGraph( const Graph& graph, xla::XlaBuilder& builder, mlir::MLIRContext& mlir_context, llvm::ArrayRef xla_params, - std::vector& returns, llvm::ArrayRef args, - llvm::ArrayRef control_rets, llvm::StringRef device_type, - const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, + std::vector& returns, bool unconditionally_use_output_shapes, + llvm::ArrayRef args, llvm::ArrayRef control_rets, + llvm::StringRef device_type, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, llvm::MutableArrayRef> custom_legalization_passes) { TF_ASSIGN_OR_RETURN( mlir::OwningOpRef module, - GraphToModule(graph, control_rets, flib_def, debug_info, &mlir_context)); + GraphToModule(unconditionally_use_output_shapes, graph, control_rets, + flib_def, debug_info, &mlir_context)); return BuildHloFromModule(module.get(), builder, xla_params, returns, args, device_type, custom_legalization_passes); } @@ -1034,7 +1042,8 @@ Status CompileGraphToXlaHlo( mlir::MLIRContext context; TF_ASSIGN_OR_RETURN( mlir::OwningOpRef module, - GraphToModule(graph, control_rets, flib_def, debug_info, &context)); + GraphToModule(/*unconditionally_use_set_output_shapes=*/false, graph, + control_rets, flib_def, debug_info, &context)); return CompileGraphToXlaHlo( module.get(), args, device_type, use_tuple_args, enable_op_fallback, /*use_return_tuple=*/true, shape_determination_fns, compilation_result, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h index 3f6e446ca28fd9..aaccd39a3db398 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h @@ -191,7 +191,7 @@ Status CompileGraphToXlaHlo( // Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata // and stores them in CompilationResult. ABSL_DEPRECATED( - "Use v1/compile_tf_graph.h::CompileTensorflowGraphToHloinstead.") + "Use v1/compile_tf_graph.h::CompileTensorflowGraphToHlo instead.") Status CompileGraphToXlaHlo( const Graph& graph, llvm::ArrayRef args, llvm::ArrayRef control_rets, llvm::StringRef device_type, @@ -206,14 +206,17 @@ Status CompileGraphToXlaHlo( // XlaBuilder. This function adds HLO to a larger HLO computation, so // HLO-level inputs are supplied, and HLO-level outputs are produced. // xla_params is the HLO-level inputs and returns is the HLO-level outputs. +// If unconditionally_use_output_shapes is true then the unregistered +// attribute _output_shapes is always used to set the output shapes of the ops. ABSL_DEPRECATED( - "Use v1/compile_tf_graph.h::CompileTensorflowGraphToHloinstead.") + "Use v1/compile_tf_graph.h::CompileTensorflowGraphToHlo instead.") Status BuildHloFromGraph( const Graph& graph, xla::XlaBuilder& builder, mlir::MLIRContext& mlir_context, llvm::ArrayRef xla_params, - std::vector& returns, llvm::ArrayRef args, - llvm::ArrayRef control_rets, llvm::StringRef device_type, - const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, + std::vector& returns, bool unconditionally_use_output_shapes, + llvm::ArrayRef args, llvm::ArrayRef control_rets, + llvm::StringRef device_type, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, llvm::MutableArrayRef> custom_legalization_passes = {}); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc index d7d8e8e4f4e894..62fbf4bb94381f 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc @@ -15,21 +15,37 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" +#include +#include #include #include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/client/xla_builder.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace { @@ -182,5 +198,59 @@ TEST(LegalizeMlirTest, LegalizesModuleWithDynamicShape) { EXPECT_TRUE(status.ok()); } +absl::StatusOr> BuildOpGraphWithOutputShapes() { + DataType data_type = DT_INT32; + std::initializer_list dims = {2, 3, 4, 5}; + Tensor tensor(data_type, TensorShape(dims)); + for (int i = 0; i < 2 * 3 * 4 * 5; ++i) { + tensor.flat()(i) = i; + } + + NodeDef node; + auto builder = NodeDefBuilder("some_node", "Const") + .Attr("dtype", data_type) + .Attr("value", tensor); + // Create a bad output shape attr. + AttrValue shape_attr; + TensorShapeProto* shape_proto = shape_attr.mutable_list()->add_shape(); + shape_proto->add_dim()->set_size(1); + builder.Attr("_output_shapes", shape_attr); + + TF_RETURN_IF_ERROR(builder.Finalize(&node)); + + return CreateSingleOpGraph(node, {}, {DataType::DT_INT32}); +} + +absl::Status BuildHloFromGraph(Graph& graph, bool use_output_shapes) { + xla::XlaBuilder builder( + ::testing::UnitTest::GetInstance()->current_test_info()->name()); + mlir::MLIRContext mlir_context; + llvm::SmallVector xla_params; + std::vector returns(1); + return BuildHloFromGraph(graph, builder, mlir_context, xla_params, returns, + use_output_shapes, /*args=*/{}, + /*control_rets=*/{}, DEVICE_TPU, + FunctionLibraryDefinition(OpRegistry::Global()), + /*debug_info=*/{}, + /*custom_legalization_passes=*/{}); +} + +TEST(CompileMlirUtil, UsesCorrectOriginalShapeWithoutOutputShapes) { + TF_ASSERT_OK_AND_ASSIGN(auto graph, BuildOpGraphWithOutputShapes()); + + auto build_result = BuildHloFromGraph(*graph, /*use_output_shapes=*/false); + TF_ASSERT_OK(build_result); +} + +TEST(CompileMlirUtil, UsesIncorrectOutputShapesWhenPresent) { + TF_ASSERT_OK_AND_ASSIGN(auto graph, BuildOpGraphWithOutputShapes()); + + auto build_result = BuildHloFromGraph(*graph, /*use_output_shapes=*/true); + ASSERT_FALSE(build_result.ok()); + EXPECT_THAT(build_result.message(), + HasSubstr("op operand type 'tensor<2x3x4x5xi32>' and result type " + "'tensor<1xi32>' are cast incompatible")); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index a92239e8dbba69..545203ad20ea23 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -119,12 +119,11 @@ cc_library( ], deps = [ ":device_type_proto_cc", - ":tf_dialect_to_executor", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", - "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir/tf2xla/internal:clustering_bridge_passes", "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks", "//tensorflow/core:framework", @@ -133,7 +132,6 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:stacktrace", "//tensorflow/core/platform:status", - "//tensorflow/core/tpu:tpu_defs", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@llvm-project//llvm:Support", @@ -143,7 +141,6 @@ cc_library( "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:error_logging", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", ], ) @@ -159,6 +156,7 @@ tf_cc_test( ":cluster_tf", "//tensorflow/compiler/mlir:register_common_dialects", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/core/lib/monitoring:cell_reader", "//tensorflow/core/platform:resource_loader", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc index 23480374032aaa..41df5eb0750459 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h" @@ -52,8 +53,6 @@ using mlir::OpPassManager; using mlir::PassManager; using mlir::func::FuncOp; -constexpr char kBridgeComponent[] = "TFXLABridge"; - // Run the TF XLA Bridge based on the input pipeline, which can be either TPU // bridge pipeline or non TPU bridge pipeline. tensorflow::Status RunTFXLABridge( @@ -114,6 +113,7 @@ tensorflow::Status RunTFXLABridge( tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, bool fallback_enabled, + std::string bridge_type, std::string device_type, absl::Status status) { if (status.ok()) { @@ -122,7 +122,7 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, VLOG(2) << error_prefix << " " << status; tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - device_type, /*bridge_version=*/"v2", + /*bridge_type*/ bridge_type, /*bridge_version=*/"v2", device_type, /*fallback_enabled=*/fallback_enabled, /*result=*/"failure"); @@ -135,7 +135,7 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE"; } - tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent, + tsl::error_logging::Log(mlir::TF::kBridgeComponent, bridge_subcomponent, status.ToString()) .IgnoreError(); @@ -162,8 +162,9 @@ void CreateReplicatedClusteringPipelineV2(OpPassManager &pm) { tensorflow::Status RunFunctionTf2xlaClusteringBridge( ModuleOp module, bool is_supported_by_replicated_brige, bool is_in_fallback_enabled_mode, llvm::StringRef module_name) { - std::string device_type_filter = - is_supported_by_replicated_brige ? "tpu" : "cpu/gpu"; + std::string device_type = is_supported_by_replicated_brige + ? mlir::TF::kMlirPh1BridgeCounterTpu + : mlir::TF::kMlirPh1BridgeCounterNonTpu; VLOG(2) << (is_supported_by_replicated_brige ? "Replicated" : "NonReplicated") @@ -186,14 +187,17 @@ tensorflow::Status RunFunctionTf2xlaClusteringBridge( }, module_name, /*dump_prefix=*/"tf_xla_bridge_v2_nonreplicated"); + std::string bridge_type = is_supported_by_replicated_brige + ? mlir::TF::kMlirPh1BridgeCounterReplicated + : mlir::TF::kMlirPh1BridgeCounterNonReplicated; // TODO(b/317798386): add is_supported_by_replicated_brige as a filter. TF_RETURN_IF_ERROR(RecordIfErrorStatus( /*error_prefix=*/"clustering_v2", is_in_fallback_enabled_mode, - device_type_filter, clustering_status)); + bridge_type, device_type, clustering_status)); // TODO(b/317798386): add is_supported_by_replicated_brige as a filter. tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - device_type_filter, /*bridge_version=*/"v2", + bridge_type, /*bridge_version=*/"v2", device_type, /*fallback_enabled=*/is_in_fallback_enabled_mode, /*result=*/"success"); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc index c4a96702533c49..a5f64a91cd8cb4 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/resource_loader.h" #include "tsl/lib/core/status_test_util.h" @@ -94,8 +95,11 @@ TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfReplicatedBridge) { FuncOp main = mlir_module_->lookupSymbol("main"); ASSERT_TRUE(main); - EXPECT_EQ( - compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_disabled", "success"), + 1); } TEST_F(FunctionClusterTensorflowDialectTest, @@ -118,8 +122,11 @@ TEST_F(FunctionClusterTensorflowDialectTest, }); EXPECT_TRUE(has_cluster_op); - EXPECT_EQ( - compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_disabled", "success"), + 1); } TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFNonReplicatedBridge) { @@ -135,7 +142,10 @@ TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFNonReplicatedBridge) { ASSERT_TRUE(main); EXPECT_EQ( - compilation_status.Delta("cpu/gpu", "v2", "fallback_disabled", "success"), + compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterNonReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, + mlir::TF::kMlirPh1BridgeCounterNonTpu, + "fallback_disabled", "success"), 1); } @@ -148,8 +158,11 @@ TEST_F(FunctionClusterTensorflowDialectTest, LogsFallbackMode) { *mlir_module_, /*is_supported_by_replicated_brige*/ true, /*is_in_fallback_enabled_mode=*/true)); - EXPECT_EQ( - compilation_status.Delta("tpu", "v2", "fallback_enabled", "success"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_enabled", "success"), + 1); } } // namespace diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD index 246481c5cab7db..7e937d2ce49f8b 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD @@ -187,6 +187,7 @@ cc_library( "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes", "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index 603d928daf9032..e289934b69fbe0 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h" #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index b1d7863e860aa6..4c6f68a3419656 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -1,7 +1,7 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( diff --git a/tensorflow/compiler/mlir/tf2xla/tests/BUILD b/tensorflow/compiler/mlir/tf2xla/tests/BUILD index 97bb01c30d1855..c68c485954de1b 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/tests/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 28a459ccff2eac..b76b52c9fd774a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -2,10 +2,10 @@ # TF2XLA Bridge transforms load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -196,6 +196,7 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@local_xla//xla/mlir_hlo", "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", ], ) @@ -286,6 +287,7 @@ cc_library( "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:shape_util", "@local_xla//xla:side_effect_util", @@ -293,13 +295,15 @@ cc_library( "@local_xla//xla/client:padding", "@local_xla//xla/client:sharding_builder", "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:chlo_legalize_to_hlo", "@local_xla//xla/mlir_hlo:convert_op_folder", + "@local_xla//xla/mlir_hlo:mhlo_passes", + "@local_xla//xla/mlir_hlo:type_conversion", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:tpu_api", "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.cc index 3d46b98d9bac90..816b9a5e8b7706 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -34,6 +35,7 @@ ConversionTarget GetDefaultLegalConversionTargets(MLIRContext& mlir_context, if (legalize_chlo) { target.addIllegalDialect(); + target.addIllegalDialect(); } else { target.addLegalDialect(); } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc index 7336c8fe625447..d3c9ff7e8bd157 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h" @@ -45,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/rewriters.h" +#include "xla/mlir_hlo/mhlo/utils/type_conversion.h" #include "tensorflow/core/lib/monitoring/counter.h" namespace mlir { @@ -203,9 +205,9 @@ LogicalResult legalizeTF(Operation *op, bool legalize_chlo, // Populate with CHLO->HLO lowerings to account for TF ops legalized to // CHLO first. + stablehlo::StablehloToHloTypeConverter hlo_converter; if (legalize_chlo) { - chlo::populateDecomposeChloPatterns(context, &patterns); - chlo::populateChloBroadcastingPatterns(context, &patterns); + chlo::populateChloToHloPatterns(context, &hlo_converter, &patterns); } // ConstantLike op is convenient to create splat constants, but is // canonicalized to plain HLO constant if statically shaped. Add the diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td index 4c3f664af9cb83..19c31018185c82 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td @@ -45,12 +45,16 @@ def LegalizeTF : Pass<"xla-legalize-tf", "ModuleOp"> { ]; let constructor = "mlir::mhlo::createLegalizeTFPass()"; - let dependentDialects = ["arith::ArithDialect, chlo::ChloDialect", - "mhlo::MhloDialect", - "quant::QuantizationDialect", - "shape::ShapeDialect", - "func::FuncDialect", - "sparse_tensor::SparseTensorDialect"]; + let dependentDialects = [ + "arith::ArithDialect", + "chlo::ChloDialect", + "func::FuncDialect", + "mhlo::MhloDialect", + "quant::QuantizationDialect", + "shape::ShapeDialect", + "sparse_tensor::SparseTensorDialect", + "stablehlo::StablehloDialect" + ]; } def LegalizeTFCollective : Pass<"xla-legalize-tf-collective", "ModuleOp"> { diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 2c49198be7bad8..1ce45fe7345c11 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow//compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" @@ -24,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h" @@ -35,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/tosa/tfl_passes.h" #include "tensorflow/compiler/mlir/tosa/transforms/passes.h" -#include "xla/mlir/framework/ir/xla_framework.h" #include "xla/mlir/framework/transforms/passes.h" #include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -69,6 +70,8 @@ int main(int argc, char **argv) { tensorflow::RegisterGraphOptimizationPasses(); tensorflow::RegisterMlProgramPasses(); mlir::TFTPU::registerRuntimeLoweringPasses(); + mlir::TFDevice::registerSparseCorePasses(); + tensorflow::tfrt_compiler::RegisterTPULowerClusterToRuntimeOpsPassPipeline(); tensorflow::tfrt_compiler:: RegisterNonTPULowerClusterToRuntimeOpsPassPipeline(); diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 04cd4282e5c451..cfb2a9b0b86a35 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -1,3 +1,8 @@ +load( + "@llvm-project//mlir:tblgen.bzl", + "gentbl_cc_library", + "td_library", +) load("//tensorflow:strict.default.bzl", "py_strict_library") load( "//tensorflow:tensorflow.bzl", @@ -5,13 +10,8 @@ load( "tf_cc_test", ) load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable", "tf_py_strict_test", "tf_python_pybind_extension") -load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") -load( - "@llvm-project//mlir:tblgen.bzl", - "gentbl_cc_library", - "td_library", -) load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( @@ -114,6 +114,7 @@ cc_library( "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:SideEffectInterfaces", @@ -163,6 +164,7 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SCFDialect", diff --git a/tensorflow/compiler/mlir/tfr/build_defs.bzl b/tensorflow/compiler/mlir/tfr/build_defs.bzl index 090523ce7da3e9..e9dd5e9178080b 100644 --- a/tensorflow/compiler/mlir/tfr/build_defs.bzl +++ b/tensorflow/compiler/mlir/tfr/build_defs.bzl @@ -1,8 +1,8 @@ """BUILD extension for TF composition project.""" +load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.default.bzl", "tf_custom_op_py_library") -load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") def gen_op_libraries( name, diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index e2157630ceb1b5..21cdf1203a3554 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -145,6 +145,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", @@ -165,6 +166,7 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "@llvm-project//mlir:IR", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", @@ -180,6 +182,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -252,7 +255,9 @@ cc_library( "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", @@ -444,6 +449,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:SideEffectInterfaces", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD b/tensorflow/compiler/mlir/tfrt/ir/BUILD index 80257d4812ecd3..68e9624e118453 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -28,6 +28,7 @@ cc_library( deps = [ ":tfrt_fallback_opdefs_inc_gen", "@llvm-project//mlir:IR", + "@llvm-project//mlir:SideEffectInterfaces", ], ) @@ -50,8 +51,12 @@ cc_library( ":tfrt_fallback_common", ":tfrt_fallback_opdefs", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:compiler_tfrt_op_interfaces", @@ -78,6 +83,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:SideEffectInterfaces", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", "@tf_runtime//:tensor_opdefs", @@ -251,6 +257,7 @@ cc_library( ":tfrt_fallback_opdefs", ":tfrt_gpu_opdefs_inc_gen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index 4b2b0576430bd1..ce69fa85189423 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -58,6 +58,7 @@ cc_library( ":mlrt_ops_inc_gen", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", ], ) @@ -166,6 +167,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_side_effects", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Transforms", "@tf_runtime//:compiler_tfrt_op_interfaces", "@tf_runtime//:compiler_tfrt_traits", @@ -183,5 +186,6 @@ cc_library( ":tf_mlrt_tpu_ops_inc_gen", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:IR", + "@llvm-project//mlir:SideEffectInterfaces", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td index 7fbc42ad3db93f..6ff38dda69bd85 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td @@ -35,6 +35,24 @@ def CreateOp: TensorflowMlrt_Op<"createop", []> { let assemblyFormat = "attr-dict"; } +def ConstOp: TensorflowMlrt_Op<"constop", []> { + let summary = "The tf_mlrt ConstOp"; + + let description = [{ + The ConstOp creates a constant tensorflow::Tensor from serialized proto. + }]; + + let arguments = (ins + StrAttr:$tensor_proto + ); + + let results = (outs + TFTensorType:$result + ); + + let assemblyFormat = "attr-dict"; +} + def ExecuteOp : TensorflowMlrt_Op<"executeop", []> { let summary = "The Fallback ExecuteOp"; let description = [{ @@ -427,7 +445,7 @@ def AsyncWhileOp : TensorflowMlrt_Op<"async_while", [Pure]> { }]; } -def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", []> { +def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", [Pure]> { let summary = "Loads a variable tensor as an IFRT array for mlrt"; let description = [{ @@ -458,5 +476,31 @@ def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", []> { ); } +def IfrtRestoreVariableOp: TensorflowMlrt_Op<"ifrt_restore_variable", []> { + let summary = "Restore variable tensors"; + let description = [{ + This is the MLRT version of tf.IfrtRestoreVariableOp. + + This Op is similar to a combination of RestoreV2 and AssignVariable Op, but + this Op's execution is asynchronous. + + This Op is specific to MLRT runtime and is not a stable interface for + serialization. + + This Op will restore the tensors asynchronously and allow the runtime to look + for them. + The runtime shall handle the possibility that the tensors are not ready when requested + because the tensors are loaded asynchronously. + + }]; + + let arguments = (ins + TFTensorType:$prefix, + TFTensorType:$tensor_names, + TFTensorType:$shape_and_slices, + Variadic:$var_handles, + TypeArrayAttr: $restored_dtypes + ); +} #endif diff --git a/tensorflow/compiler/mlir/tfrt/tests/BUILD b/tensorflow/compiler/mlir/tfrt/tests/BUILD index cdae75eea036d2..1efb1ac7a16322 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD b/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD index c9b64b7b4fb625..cfe04b0689155d 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir index e6d5aec8285a0b..cf14af8f3d35f8 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir @@ -96,6 +96,30 @@ func.func @hoist_const_return(%arg: tensor {tf_saved_model.index_path = ["i module attributes {tf_saved_model.semantics} { +// Test not hoisting `tf.BatchFunction`. + +// CHECK-LABEL: func @_tfrt_resource_init +// CHECK: [[const:%.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> {device = "/CPU:0"} : () -> tensor<1xi32> +// CHECK: "tf._TfrtSetResource"([[const]]) <{index = 0 : i64}> {device = "/CPU:0"} : (tensor<1xi32>) -> () + +// CHECK-LABEL: func.func private @func_with_batch_function +func.func private @func_with_batch_function() -> tensor<*xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "StatefulPartitionedCall:0"}} { + // CHECK: "tf._TfrtGetResource"() + %cst = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> {device = "/CPU:0"} : () -> tensor<1xi32> + // CHECK: "tf.BatchFunction" + %0 = "tf.BatchFunction"(%cst) <{allowed_batch_sizes = [1], batch_timeout_micros = 5000 : i64, batching_queue = "", container = "", enable_large_batch_splitting = true, f = @_batched, low_priority_allowed_batch_sizes = [], low_priority_batch_timeout_micros = 0 : i64, low_priority_max_batch_size = 0 : i64, low_priority_max_enqueued_batches = 0 : i64, max_batch_size = 1 : i64, max_enqueued_batches = 1 : i64, num_batch_threads = 1 : i64, operandSegmentSizes = array, shared_name = "batch_function___inference_signature_wrapper_fn_with_defaults_36"}> {device = "/CPU:0"} : (tensor<1xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} +func.func private @_batched(%arg0: tensor<1xi32>) -> tensor<1xi32> { + return %arg0 : tensor<1xi32> +} + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // Test hoisting write side-effect ops. // CHECK-LABEL: func @_tfrt_resource_init diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/lower_to_ifrt_restore_variable.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/lower_to_ifrt_restore_variable.mlir index 46f7f52195deca..5052694566de89 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ifrt/lower_to_ifrt_restore_variable.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/lower_to_ifrt_restore_variable.mlir @@ -25,6 +25,32 @@ module { } } +// ----- +// single variable: VarHandleOp is before RestoreV2 + +// CHECK-LABEL: func.func @varhandle_before_restore() { +// CHECK-NEXT: [[PREFIX:%.*]] = "tf.Const"() <{value = dense<"restore_ariables"> : tensor}> : () -> tensor +// CHECK-NEXT: [[SLICE:%.*]] = "tf.Const"() <{value = dense<""> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string> +// CHECK-NEXT: [[NAME:%.*]] = "tf.Const"() <{value = dense<"y"> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string> +// CHECK-NEXT: [[HANDLEY:%.*]] = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> +// CHECK-NEXT: "tf.IfrtRestoreVariableOp"([[PREFIX]], [[NAME]], [[SLICE]], [[HANDLEY]]) +// CHECK-SAME: {restored_dtypes = [f32]} +// CHECK-NOT: "tf.RestoreV2" +// CHECK-NEXT: return + +module { + func.func @varhandle_before_restore() { + %cst = "tf.Const"() <{value = dense<"restore_ariables"> : tensor}> : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<""> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string> + %cst_1 = "tf.Const"() <{value = dense<"y"> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string> + %1 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %0 = "tf.RestoreV2"(%cst, %cst_1, %cst_0): (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<3x1xf32> + "tf.AssignVariableOp"(%1, %0) : (tensor>>, tensor<3x1xf32>) -> () + return + } +} + + // ----- // multiple variables diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_identity_propagation.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_identity_propagation.mlir new file mode 100644 index 00000000000000..6ff8613283472d --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_identity_propagation.mlir @@ -0,0 +1,38 @@ +// RUN: tf-tfrt-opt %s -tf-identity-propagation -canonicalize | FileCheck %s + +// CHECK-LABEL: func @identity +// CHECK-SAME: (%[[ARG0:.*]]: tensor) +func.func @identity(%arg0: tensor) -> tensor { + // CHECK-NOT: "tf.Identity" + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + // CHECK: return %[[ARG0]] + func.return %0 : tensor +} + +// CHECK-LABEL: func @identity_terminator +// CHECK-SAME: (%[[ARG0:.*]]: tensor) +func.func @identity_terminator(%arg0: tensor) -> (tensor<*xi32>, tensor) { + // CHECK: %[[IDENTITY:.*]] = "tf.Identity" + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor<*xi32> + // CHECK-NOT: "tf.Identity" + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // CHECK: return %[[IDENTITY]], %[[ARG0]] + func.return %0, %1 : tensor<*xi32>, tensor +} + +// CHECK-LABEL: func @xla_sharding +func.func @xla_sharding(%arg0: tensor) -> tensor { + // CHECK: %[[OUTPUT:.*]] = "tf.Identity" + %0 = "tf.Identity"(%arg0) {_XlaSharding = ""} : (tensor) -> tensor + // CHECK: return %[[OUTPUT]] + func.return %0 : tensor +} + +// CHECK-LABEL: func @identity_n +// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) +func.func @identity_n(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK-NOT: "tf.IdentityN" + %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor, tensor) -> (tensor, tensor) + // CHECK: return %[[ARG0]], %[[ARG1]] + func.return %0#0, %0#1 : tensor, tensor +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_pruning.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_pruning.mlir new file mode 100644 index 00000000000000..3055438d5c468d --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_pruning.mlir @@ -0,0 +1,25 @@ +// RUN: tf-tfrt-opt -tf-restore-pruning %s | FileCheck %s + +// CHECK-LABEL: func.func @prune_unused_restore +func.func @prune_unused_restore() { + %cst = "tf.Const"() <{value = dense<"restore_ariables"> : tensor}> : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<""> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string> + %cst_1 = "tf.Const"() <{value = dense<"y"> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string> + // CHECK-NOT: tf.RestoreV2 + %0 = "tf.RestoreV2"(%cst, %cst_1, %cst_0): (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<3x1xf32> + %1 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + return +} + + +// CHECK-LABEL: func.func @used_restore_remains +func.func @used_restore_remains() { + %cst = "tf.Const"() <{value = dense<"restore_ariables"> : tensor}> : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<""> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string> + %cst_1 = "tf.Const"() <{value = dense<"y"> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string> + // CHECK: tf.RestoreV2 + %0 = "tf.RestoreV2"(%cst, %cst_1, %cst_0): (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<3x1xf32> + %1 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + "tf.AssignVariableOp"(%1, %0) : (tensor>>, tensor<3x1xf32>) -> () + return +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD b/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD index 8d49d08b1025f8..44fc2c0f6945b4 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD index 90bd835edb2828..1d2b470c8adb91 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/inline.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/inline.mlir index de2a29c017df30..88bc197c8e88d6 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/inline.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/inline.mlir @@ -29,7 +29,7 @@ func.func @while_body_if(%cond: tensor, %x: tensor, %y: tensor, %z: // CHECK-LABEL: func @while_test_if // CHECK-SAME: -> !tf_mlrt.tensor func.func @while_test_if(%cond: tensor, %x: tensor, %y: tensor) -> (tensor) { - // CHECK: [[CONST:%.*]] = tf_mlrt.executeop + // CHECK: [[CONST:%.*]] = tf_mlrt.constop {tensor_proto = "\08\03\12\00"} %cst = "tf.Const"() {__op_key = 2: i32, device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor // Predicate should be inlined. // CHECK-NEXT: tf_mlrt.predicate diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir index eb2e0587364d6e..3cb879dabe97f7 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir @@ -236,7 +236,7 @@ func.func @while_body_add2(%arg0: tensor) -> tensor { // CHECK-LABEL: func @while_test() // CHECK-SAME: -> !tf_mlrt.tensor func.func @while_test() -> (tensor) { - // CHECK: [[CONST:%.*]] = tf_mlrt.executeop + // CHECK: [[CONST:%.*]] = tf_mlrt.constop %0 = "tf.Const"() {__op_key = 4: i32, device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor // CHECK: [[pred_res:%.*]] = call @"while_cond_lt9/tf_mlrt_predicate"([[CONST]]) : (!tf_mlrt.tensor) -> i1 // CHECK: [[while_res:%.*]]:2 = mlrt.while @@ -353,8 +353,7 @@ func.func @main(%input0: tensor) -> tensor { {callee = @main_stream_0} : (tensor, !mlrt.promise) -> !mlrt.async_handle - // CHECK: [[const:%.*]] = tf_mlrt.executeop - // CHECK-SAME: Const + // CHECK: [[const:%.*]] = tf_mlrt.const %const = "tf.Const"() {__op_key = 1: i32, value = dense<2> : tensor} : () -> tensor // CHECK: [[b:%.*]] = tf_mlrt.await [[futures]] @@ -476,3 +475,24 @@ func.func @ifrt_load_variable_test() -> () { func.return } +// ----- + +// Test lowering of IfrtRestoreVariableOp + +// CHECK-LABEL: func @ifrt_restore_variable_test +func.func @ifrt_restore_variable_test() -> () { + // CHECK-NEXT: [[PREFIX:%.*]] = tf_mlrt.constop + %cst = "tf.Const"() {__op_key = 0: i32, value = dense<"restore_ariables"> : tensor} : () -> tensor + // CHECK-NEXT: [[SLICE:%.*]] = tf_mlrt.constop + %cst_0 = "tf.Const"() {__op_key = 1: i32, value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + // CHECK-NEXT: [[NAME:%.*]] = tf_mlrt.constop + %cst_1 = "tf.Const"() {__op_key = 2: i32, value = dense<["y"]> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + // CHECK-NEXT: [[HANDLE:%.*]] = tf_mlrt.executeop + %handle = "tf.VarHandleOp"() {__op_key = 3: i32, container = "x", shared_name = "y"} : () -> tensor>> + // CHECK-NEXT: "tf_mlrt.ifrt_restore_variable"([[PREFIX]], [[NAME]], [[SLICE]], [[HANDLE]]) {restored_dtypes = [f32]} + "tf.IfrtRestoreVariableOp"(%cst, %cst_1, %cst_0, %handle) {restored_dtypes = [f32]} : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor>>) -> () + // CHECK-NEXT: return + func.return +} + + diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD index 60823b2abba41c..70c2235a20a104 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/BUILD b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/BUILD index 1065a5fc1a682a..bbcc6e963788c9 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_shared_test") load("@tf_runtime//tools:mlir_to_bef.bzl", "glob_tfrt_lit_tests", "mlir_to_bef") +load("//tensorflow:tensorflow.bzl", "tf_cc_shared_test") # copybara:uncomment load("//third_party/tf_runtime_google/cpp_tests:gen_tests.bzl", "tfrt_cc_test_and_strict_benchmark") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 8e15b9fcfee8ac..6ef5c011d0a11d 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -17,6 +17,7 @@ package_group( "//learning/brain/tfrt/cpp_tests/...", "//learning/pathways/serving/runtime/...", "//learning/pathways/serving/tests/...", + "//learning/brain/tfrt/mlir/mlrt/application/pathways/compiler/...", # Allow visibility from the mlir language server. "//learning/brain/mlir/mlir_lsp_server/...", ]), @@ -49,14 +50,27 @@ cc_library( ], ) +cc_library( + name = "ifrt_types", + srcs = [], + hdrs = ["ifrt_types.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "tf_ifrt_passes", srcs = [ "lower_to_ifrt_restore_variable.cc", "rewrite_cluster_to_ifrt_call.cc", "sink_variable_as_named_array.cc", + "tf_identity_propagation.cc", "tf_ifrt_passes.cc", "tf_restore_merging.cc", + "tf_restore_pruning.cc", "tf_restore_splitting.cc", ], hdrs = [ @@ -108,6 +122,7 @@ cc_library( hdrs = ["tf2hlo.h"], deps = [ ":ifrt_constants", + ":ifrt_types", "//tensorflow/compiler/jit:xla_cpu_jit", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h new file mode 100644 index 00000000000000..c64672cdb10e69 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_TYPES_H_ + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +struct DtypeAndShape { + tensorflow::DataType dtype; + tensorflow::TensorShape shape; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_TYPES_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/lower_to_ifrt_restore_variable.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/lower_to_ifrt_restore_variable.cc index 9effab181c1566..7c0fa364b593a7 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/lower_to_ifrt_restore_variable.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/lower_to_ifrt_restore_variable.cc @@ -49,16 +49,15 @@ class LowerToIfrtRestoreVariablePass void runOnOperation() override { mlir::ModuleOp module = getOperation(); - mlir::WalkResult walk_result = - module.walk([&](mlir::TF::RestoreV2Op restore_op) { - if (mlir::failed(RewriteRestore(restore_op))) { - return mlir::WalkResult::interrupt(); - } - return mlir::WalkResult::advance(); - }); - - if (walk_result.wasInterrupted()) { - return signalPassFailure(); + std::vector restore_ops; + module.walk([&](mlir::TF::RestoreV2Op restore_op) { + restore_ops.push_back(restore_op); + }); + + for (const auto& restore_op : restore_ops) { + if (mlir::failed(RewriteRestore(restore_op))) { + return signalPassFailure(); + } } } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td index 20bddd75722c63..7cdc5576ae5465 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td @@ -59,6 +59,20 @@ def LowerToIfrtRestoreVariablePass: Pass<"lower-to-ifrt-restore-variable", "mlir let constructor = "CreateLowerToIfrtRestoreVariablePass()"; } +def TfRestorePruningPass + : Pass<"tf-restore-pruning", "mlir::func::FuncOp"> { + let summary = "Prune unused`tf.RestoreV2` ops"; + + let description = [{ + This pass prune unused `tf.RestoreV2` op. A typical use case is to combine + `TfRestoreSplittingPass`, this pass and `TfRestoreMergingPass` in sequence + so that the un-used restored tensors are not read into host memory. + }]; + + let constructor = "CreateTfRestorePruningPass()"; +} + + def TfRestoreSplittingPass : Pass<"tf-restore-splitting", "mlir::func::FuncOp"> { let summary = "Splits `tf.RestoreV2` ops"; @@ -89,4 +103,29 @@ def TfRestoreMergingPass : Pass<"tf-restore-merging", "mlir::func::FuncOp"> { }]; let constructor = "CreateTfRestoreMergingPass()"; -} \ No newline at end of file +} + +def TfIdentityPropagationPass + : Pass<"tf-identity-propagation", "mlir::func::FuncOp"> { + let summary = "Propagates inputs of no-op identity ops to their outputs"; + + let description = [{ + This pass finds identity ops that are no-op and propagates their inputs + directly to outputs so that identity ops can be skipped. + + One example of identity ops that are not no-op is identity ops with XLA + sharding annotation. Since some models use identity ops with `_XlaSharding` + attributes to change output sharding, this pass doesn't propagate the inputs + of such identity ops in order to preserve the sharding changes. + + This pass is useful to make sure that ineffective identity ops don't affect + the graph partitioning. For example, in a pipelined model, if there is a CPU + identity op between two TPU computation stages (which sometimes happens + because TensorFlow inserts it), this will unnecessarily route the + intermediate tensors through the CPU device. By forwarding the inputs of the + identity op directly to its outputs, we can avoid such inefficiency. + }]; + + let constructor = "CreateTfIdentityPropagationPass()"; +} + diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc index 312761a3ba06d7..a0b01ba1ffc3f7 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -127,6 +128,7 @@ absl::StatusOr GetCompileMetadata( // Create a default device assignment if one is not given by the model. if (!metadata.has_device_assignment()) { + // TODO(b/316068010): integrate core selection. TF_ASSIGN_OR_RETURN( auto device_assignment, ifrt_client.GetDefaultDeviceAssignment( diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h index 74fa271401f547..fec9bbb2c740e7 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/client.h" #include "tensorflow/core/framework/tensor.h" @@ -31,11 +32,6 @@ limitations under the License. namespace tensorflow { namespace ifrt_serving { -struct DtypeAndShape { - tensorflow::DataType dtype; - tensorflow::TensorShape shape; -}; - struct Tf2HloResult { mlir::OwningOpRef mlir_hlo_module; tensorflow::tpu::TPUCompileMetadataProto compile_metadata; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_identity_propagation.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_identity_propagation.cc new file mode 100644 index 00000000000000..873838727c9d33 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_identity_propagation.cc @@ -0,0 +1,88 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_TFIDENTITYPROPAGATIONPASS +#define GEN_PASS_DECL_TFIDENTITYPROPAGATIONPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +constexpr absl::string_view kXlaShardingAttr = "_XlaSharding"; + +bool IsTerminator(mlir::Operation* op) { + return op->hasTrait(); +} + +class TfIdentityPropagationPass + : public impl::TfIdentityPropagationPassBase { + public: + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + + func.walk([](mlir::TF::IdentityOp identity) { + // Don't propagate inputs of identity ops with sharding annotation since + // identity ops are sometimes used to change output sharding. + if (identity->hasAttr(kXlaShardingAttr)) { + return; + } + // Identity outputs to terminator ops (e.g., `func.return`) cannot be + // replaced unless input/output types are exactly the same. Doing so may + // cause mismatch between the enclosing region's return type and the + // terminator's arg type. + const bool same_type = + identity.getInput().getType() == identity.getOutput().getType(); + identity.getOutput().replaceUsesWithIf( + identity.getInput(), [&](mlir::OpOperand& operand) { + return same_type || !IsTerminator(operand.getOwner()); + }); + }); + + func.walk([](mlir::TF::IdentityNOp identity_n) { + if (identity_n->hasAttr(kXlaShardingAttr)) { + return; + } + for (auto [input, output] : + llvm::zip(identity_n.getInput(), identity_n.getOutput())) { + const bool same_type = input.getType() == output.getType(); + output.replaceUsesWithIf(input, [&](mlir::OpOperand& operand) { + return same_type || !IsTerminator(operand.getOwner()); + }); + } + }); + } +}; + +} // namespace + +std::unique_ptr> +CreateTfIdentityPropagationPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc index 53bd55cc0d2799..9737c681d28aa8 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc @@ -70,6 +70,14 @@ void AddClusterToIfrtRuntimeOpsPassPipeline(OpPassManager& pm, pm.addNestedPass( mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); + pm.addNestedPass(CreateTfIdentityPropagationPass()); + + pm.addNestedPass(CreateTfRestoreSplittingPass()); + pm.addNestedPass(CreateTfRestorePruningPass()); + pm.addNestedPass(CreateTfRestoreMergingPass()); + + pm.addPass(CreateLowerToIfrtRestoreVariablePass()); + pm.addPass(CreateRewriteClusterToIfrtCallPass()); // Sink VarHandle with ReadVariableOp: subsequent SinkVariableAsNamedArrayPass diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h index 3835a77f04f93c..93713fbdc13646 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h @@ -44,6 +44,14 @@ CreateTfRestoreSplittingPass(); std::unique_ptr> CreateTfRestoreMergingPass(); +// Creates a pass that propagates inputs of no-op identity ops to their outputs. +std::unique_ptr> +CreateTfIdentityPropagationPass(); + +// Creates a pass that prunes unused `tf.RestoreV2` ops. +std::unique_ptr> +CreateTfRestorePruningPass(); + // Creates a pass that lower `tf.RestoreVariableOp` to // `tf.IfrtRestoreVariableOp`. std::unique_ptr> diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_pruning.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_pruning.cc new file mode 100644 index 00000000000000..6491be3f7151fa --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_pruning.cc @@ -0,0 +1,52 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_TFRESTOREPRUNINGPASS +#define GEN_PASS_DECL_TFRESTOREPRUNINGPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +// Prune unused RestoreV2 Op. +class TfRestorePruningPass + : public impl::TfRestorePruningPassBase { + public: + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + func.walk([&](mlir::TF::RestoreV2Op restore) { + if (restore.use_empty()) { + restore.erase(); + } + }); + } +}; + +} // namespace + +std::unique_ptr> +CreateTfRestorePruningPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc index 085d77df441e2e..17e3d8be95204d 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc @@ -115,6 +115,10 @@ bool CanHoist(const llvm::DenseSet &read_only_vars, // return ops should not be hoisted. if (op->mightHaveTrait()) return false; + // Fixes a corner case where hoisting the tf.BatchFunction leads to + // a compilation error; such a case may occur in unit tests. + if (llvm::isa(op)) return false; + // Non-side-effecting ops can be hoisted. if (mlir::isMemoryEffectFree(op)) return true; @@ -402,7 +406,7 @@ void LowerTFSavedModelPass::HoistInvariantOps(mlir::ModuleOp module) { } else if (auto func = llvm::dyn_cast(op)) { if (!IsSessionInitializer(func)) return; FindCalleesRecursive(symbol_table, func, init_callees); - } else if (op->getName().getStringRef().str() == "tf.XlaLaunch") { + } else if (llvm::isa(op)) { // TODO(b/275095412): Clean up MLIR XLA functions after they are written // back to function library, so that we don't need to do special handling // for those functions here. diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD index 1bb99fa64ebaf7..7d28571db5030a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD @@ -65,6 +65,7 @@ cc_library( ":tpu_conversion_patterns", ":util", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow/ir/host_runtime:tensorflow_tfrt_ops_inc_gen", "//tensorflow/compiler/mlir/tfrt:constants", @@ -78,12 +79,14 @@ cc_library( "//tensorflow/core/tfrt/fallback:fallback_state", "//tensorflow/core/tfrt/fallback:op_kernel_runner_cache", "@com_google_protobuf//:protobuf_headers", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:status", ], ) @@ -220,6 +223,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", ], diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index 8271a5c796e5c4..37ddf0b1bf076d 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "google/protobuf/text_format.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tfrt/constants.h" #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h" @@ -52,6 +54,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/fallback/fallback_state.h" #include "tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h" +#include "tsl/platform/status.h" namespace tensorflow { namespace mlrt_compiler { @@ -343,6 +346,26 @@ class IfrtLoadVariableOpConversion } }; +// Convert tf.IfrtRestoreVariableOp to tf_mlrt.IfrtRestoreVariableOp +class IfrtRestoreVariableOpConversion + : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::TF::IfrtRestoreVariableOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto new_op = rewriter.create( + op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], + adaptor.getOperands()[2], + adaptor.getOperands().slice(3, adaptor.getOperands().size() - 3), + op.getRestoredDtypes()); + rewriter.replaceOp(op, new_op); + + return mlir::success(); + } +}; + std::optional DecodeLongName(mlir::Location loc) { if (auto name_loc = loc.dyn_cast()) { return name_loc.getName().str(); @@ -422,6 +445,18 @@ class ExecuteOpConversion final : public mlir::ConversionPattern { // TODO(b/173017701): Avoid fallback for ops within XLA GPU clusters. if (!UseFallback(op)) return mlir::failure(); + if (auto const_op = llvm::dyn_cast(op)) { + tensorflow::TensorProto tensor_proto; + auto status = ConvertToTensorProto(const_op.getValue(), &tensor_proto); + if (!status.ok()) + return const_op.emitError(tsl::NullTerminatedMessage(status)); + + rewriter.replaceOpWithNewOp( + op, rewriter.getType(), + tensor_proto.SerializeAsString()); + return mlir::success(); + } + // The assign_op_key pass should have ran. if (!op->hasAttr(tensorflow::tfrt_compiler::kOpKeyAttrName)) return op->emitError("does not have op_key defined"); @@ -1189,7 +1224,8 @@ class TfToMlrtConversionPass patterns.add(&context, &type_converter_, &symbol_table); patterns.add(&context); + IfrtRestoreVariableOpConversion, TFAwaitOpConversion, + TFPromiseOpConversion>(&context); patterns.add(type_converter_, &context); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc index d9e1b7f73ac0c8..a1f9d401f5c485 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc @@ -35,9 +35,9 @@ bool UseFallback(mlir::Operation *op) { return !llvm::isa< mlir::TF::_TfrtSetResourceOp, mlir::TF::_TfrtGetResourceOp, mlir::TF::BatchFunctionOp, mlir::TF::CaseOp, mlir::TF::IfrtLoadVariableOp, - mlir::TF::StatefulPartitionedCallOp, mlir::TF::PartitionedCallOp, - mlir::TF::LegacyCallOp, mlir::TF::IfOp, mlir::TF::WhileOp, - mlir::TF::TPUCompileMlirAndExecuteOp>(op); + mlir::TF::IfrtRestoreVariableOp, mlir::TF::StatefulPartitionedCallOp, + mlir::TF::PartitionedCallOp, mlir::TF::LegacyCallOp, mlir::TF::IfOp, + mlir::TF::WhileOp, mlir::TF::TPUCompileMlirAndExecuteOp>(op); } } // namespace mlrt_compiler diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 66aee10db7e050..f61a087e782704 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -147,35 +147,6 @@ StatusOr> ExportXlaFunctions( } // namespace -Status ConvertFunctionToBef( - mlir::StringRef function_name, const tensorflow::FunctionBody* fbody, - const FunctionLibraryDefinition& flib_def, - tfrt::ArrayRef devices, - const tensorflow::TfrtFunctionCompileOptions& options, - tfrt::BefBuffer* bef_buffer) { - mlir::MLIRContext context; - // FunctionDef -> TF Dialect - auto expected_module = - tensorflow::ConvertFunctionToMlir(fbody, flib_def, &context); - - if (!expected_module.ok()) - return absl::InternalError(absl::StrCat( - "Failed to convert function to mlir for function ", function_name.str(), - ". Error: ", expected_module.status().message())); - - auto module = std::move(expected_module).value(); - - // Attach devices to the MLIR module. - if (!devices.empty()) { - mlir::Builder builder(module->getContext()); - module->getOperation()->setAttr("tf.devices", - builder.getStrArrayAttr(devices)); - } - - // TF Dialect -> BEF - return tensorflow::CompileTFMLIRToBEF(options, module.get(), bef_buffer); -} - Status ConvertTfMlirToRuntimeExecutable( const TfrtCompileOptions& options, mlir::ModuleOp module, absl::FunctionRef devices, - const tensorflow::TfrtFunctionCompileOptions& options, - tfrt::BefBuffer* bef_buffer); - // Converts an MLIR `module` in TF dialect to TFRT's Binary Executable Format. // If `fallback_state` is not null, the MLIR functions for XLA clusters in // the form of XlaLaunch will be exported and added to the function library when diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index d73899b6f85ecb..86e2e269e4d329 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -1,12 +1,20 @@ load( - "//tensorflow:tensorflow.bzl", - "check_deps", - "tf_cc_binary", + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", +) +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", ) load( "@local_xla//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured", ) +load( + "//tensorflow:tensorflow.bzl", + "check_deps", + "tf_cc_binary", +) load( "//tensorflow/core/platform:build_config.bzl", "tf_proto_library", @@ -19,14 +27,6 @@ load( "if_llvm_system_z_available", "if_llvm_x86_available", ) -load( - "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_is_configured", -) -load( - "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -65,6 +65,7 @@ cc_library( "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ComplexToStandard", "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToGPURuntimeTransforms", @@ -88,6 +89,7 @@ cc_library( "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:all_passes", # fixdeps: keep @@ -125,6 +127,7 @@ tf_cc_binary( "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:ExecutionEngineUtils", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:MemRefTransforms", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index acec5d7ae27ff5..42d679c35d0173 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -1,6 +1,6 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -86,6 +86,7 @@ cc_library( "@com_google_absl//absl/status", "@llvm-project//mlir:AllocationOpInterface", "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 18c5ab830d4722..c4abb6420d9b38 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -1,5 +1,4 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", @@ -8,6 +7,7 @@ load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -39,6 +39,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -57,6 +58,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:chlo_ops", ], @@ -73,6 +75,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -148,6 +151,7 @@ cc_library( "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", @@ -213,8 +217,10 @@ cc_library( "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:ShapeTransforms", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo:transforms_passes", ], diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index e25d2229c605c8..a7d9610a472308 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -102,6 +102,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/framework/fixedpoint", ], @@ -157,6 +158,7 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -219,6 +221,7 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -248,6 +251,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 498c9bfe11bbae..d255b67ccff83f 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1,7 +1,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") -load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites", "tf_xla_py_strict_test") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_cuda_cc_test") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") +load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites", "tf_xla_py_strict_test") load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", @@ -260,6 +260,7 @@ tf_xla_py_strict_test( "cpu", "gpu", "gpu_a100", + "gpu_h100", ], python_version = "PY3", shard_count = 2, @@ -679,6 +680,7 @@ tf_xla_py_strict_test( "cpu", "gpu", "gpu_a100", + "gpu_h100", ], python_version = "PY3", tags = [ @@ -925,6 +927,7 @@ tf_xla_py_strict_test( "cpu", "gpu", "gpu_a100", + "gpu_h100", ], python_version = "PY3", shard_count = 10, @@ -1517,6 +1520,7 @@ tf_xla_py_strict_test( disabled_backends = [ "gpu", "gpu_a100", + "gpu_h100", ], enable_mlir_bridge = True, python_version = "PY3", @@ -1554,6 +1558,7 @@ tf_xla_py_strict_test( disabled_backends = [ "gpu", "gpu_a100", + "gpu_h100", ], enable_mlir_bridge = True, python_version = "PY3", @@ -1591,6 +1596,7 @@ tf_xla_py_strict_test( disabled_backends = [ "gpu", "gpu_a100", + "gpu_h100", ], # TODO(b/232442915): Enable MLIR. enable_mlir_bridge = False, @@ -2343,6 +2349,7 @@ tf_xla_py_strict_test( disabled_backends = [ "gpu", "gpu_a100", + "gpu_h100", ], enable_mlir_bridge = True, python_version = "PY3", @@ -2378,6 +2385,7 @@ tf_xla_py_strict_test( disabled_backends = [ "gpu", "gpu_a100", + "gpu_h100", ], enable_mlir_bridge = False, python_version = "PY3", @@ -2605,6 +2613,7 @@ tf_xla_py_strict_test( "cpu", "gpu", "gpu_a100", + "gpu_h100", ], tags = [ "no_pip", @@ -2633,6 +2642,7 @@ tf_xla_py_strict_test( "cpu_ondemand", "gpu", "gpu_a100", + "gpu_h100", ], enable_mlir_bridge = False, main = "where_op_test.py", @@ -2719,6 +2729,7 @@ tf_xla_py_strict_test( "cpu_ondemand", "gpu", "gpu_a100", + "gpu_h100", ], enable_mlir_bridge = False, python_version = "PY3", @@ -2883,6 +2894,7 @@ tf_xla_py_strict_test( "cpu", "gpu", "gpu_a100", + "gpu_h100", ], python_version = "PY3", tags = [ diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 7343bb9b89efce..ce6b626683e281 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -1,7 +1,7 @@ """Build rules for Tensorflow/XLA testing.""" -load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:strict.default.bzl", "py_strict_test") +load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") load( "//tensorflow/core/platform:build_config_root.bzl", @@ -84,7 +84,7 @@ def tf_xla_py_test( "--test_device=" + cpu_xla_device, "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_COMPLEX128", ] - elif backend in ("gpu", "gpu_a100"): + elif backend in ("gpu", "gpu_a100", "gpu_h100"): backend_args += [ "--test_device=" + gpu_xla_device, "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_COMPLEX128,DT_BFLOAT16", @@ -125,7 +125,7 @@ def tf_xla_py_test( # # This is for testing book keeping because the bridge does not have any gpu specific # logic at this time, so CPU testing is good enough and cheaper. - extra_tag = ["ondemand"] if backend in ("gpu", "gpu_a100") else [] + extra_tag = ["ondemand"] if backend in ("gpu", "gpu_a100", "gpu_h100") else [] elif has_mlir_dep: # Some tests run only with mlir_bridge by explicitly adding the MLIR # bridge dep so if the dep is already present skip non MLIR diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 91ef722b52db86..498aa0f91e487a 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -3,6 +3,11 @@ # and provide TensorRT operators and converter package. # APIs are meant to change over time. +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) load("//tensorflow:strict.default.bzl", "py_strict_library") # Placeholder: load py_proto_library @@ -28,11 +33,6 @@ load( "if_static", ) load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load( - "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "cuda_rpath_flags", -) -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 7e1d80e9e8676d..01e85cc7c6cfc7 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,24 +1,25 @@ +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") + # copybara:uncomment_begin(google-only) # load("//devtools/deps/check:deps_check.bzl", "check_dependencies") # copybara:uncomment_end +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load("@local_xla//xla:xla.bzl", "xla_py_proto_library") +load("@local_xla//xla/service/cpu:build_defs.bzl", "runtime_copts") load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "if_google", "if_libtpu", "tf_cc_binary", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_gen_op_wrapper_py", "tf_openmp_copts") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "filegroup", "get_compatible_with_portable") -load("@local_xla//xla:xla.bzl", "xla_py_proto_library") -load("@local_xla//xla/service/cpu:build_defs.bzl", "runtime_copts") load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_tensor_coding_deps", "tf_proto_library", ) load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") -load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") -load( - "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -302,6 +303,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@ducc//:fft_wrapper", "@eigen_archive//:eigen3", + "@llvm-project//mlir:TransformUtils", "@local_xla//xla:empty", "//tensorflow/core/framework:numeric_types", "//tensorflow/core/platform:bfloat16", @@ -418,6 +420,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", + "@com_google_absl//absl/types:span", "@local_xla//xla:cpu_function_runtime", "@local_xla//xla:shape_util", "@local_xla//xla:statusor", @@ -1169,11 +1172,11 @@ cc_library( hdrs = ["mlir_bridge_pass.h"], visibility = [":internal"], deps = [ - ":tf2xla_defs", ":xla_op_registry", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:device_util", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index ccfd67e223b1b3..7ead605ca65c07 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.default.bzl", "tf_gen_op_wrapper_cc") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index f06596bfe6530b..6a60149d7cc4a1 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1,3 +1,4 @@ +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load( "//tensorflow:tensorflow.bzl", "if_cuda_or_rocm", @@ -11,7 +12,6 @@ load( "tf_proto_library", ) load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 6ab1793d493eaf..5c19b9fe1014d3 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -160,7 +160,7 @@ absl::Status XlaCallModuleLoader::SetPlatformIndex( } } - if (platform_index < 0) return tsl::OkStatus(); + if (platform_index < 0) return absl::OkStatus(); VLOG(3) << "XlaCallModule setting the platform_index to " << platform_index << " for platform " << compilation_platform << "."; mlir::Block &main_body = main_.front(); @@ -193,7 +193,7 @@ absl::Status XlaCallModuleLoader::SetPlatformIndex( main_.eraseArgument(0); platform_index_arg_set_ = true; - return tsl::OkStatus(); + return absl::OkStatus(); } static mlir::stablehlo::CustomCallOp MakeShapeRefinementOperandWrapper( @@ -232,13 +232,13 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( VLOG(3) << "XlaCallModule skipping shape refinement due to module " << " attribute " << kUsesShapePolymorphismAttr.str() << "=" << mlir::debugString(uses_shape_poly_attr); - return tsl::OkStatus(); + return absl::OkStatus(); } } else { VLOG(3) << "XlaCallModule skipping shape refinement due to module " << " attribute " << kUsesShapePolymorphismAttr.str() << " missing"; - return tsl::OkStatus(); + return absl::OkStatus(); } } // Add the tokens to the input_shapes. Starting with version 9, the main @@ -430,7 +430,7 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( DumpMlirOpToFile("xla_call_module.after_shape_refinement", *module_); } - return tsl::OkStatus(); + return absl::OkStatus(); } absl::Status XlaCallModuleLoader::LoadModule( @@ -527,7 +527,7 @@ absl::Status XlaCallModuleLoader::LoadModule( " arguments of which ", nr_platform_args, " platform index arguments, ", "and ", nr_token_arguments, " token arguments.")); } - return tsl::OkStatus(); + return absl::OkStatus(); } absl::Status XlaCallModuleLoader::ValidateDialect() { @@ -550,7 +550,7 @@ absl::Status XlaCallModuleLoader::ValidateDialect() { absl::StrCat("Module has unsupported dialects: ", diag_handler.ConsumeStatus().ToString())); } - return tsl::OkStatus(); + return absl::OkStatus(); } absl::Status XlaCallModuleLoader::ValidateStaticShapes() { @@ -563,8 +563,8 @@ absl::Status XlaCallModuleLoader::LowerModuleToMhlo() { mlir::PassManager pm(module_->getContext()); applyTensorflowAndCLOptions(pm); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - pm.addNestedPass(mlir::mhlo::createChloLegalizeToHloPass( - /*legalizeBroadcasts=*/true, /*expandCompositions=*/true)); + pm.addNestedPass( + mlir::mhlo::createChloLegalizeToHloPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); // In order to export to XLA, we must sink constants to control flow // regions, since XLA uses functional control flow. diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 0402508fe92f56..c24654c894b34f 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -29,13 +29,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h" #include "tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h" -#include "tensorflow/compiler/tf2xla/tf2xla_defs.h" +// #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/device.h" @@ -162,16 +163,28 @@ MlirOptimizationPassState GetPassStateImpl( << " Bridge, disabled by user. " "The fallback will evaluate."; metrics::UpdateTfMlirBridgeFirstPhaseCounter( - is_supported_by_replicated_brige ? "tpu" : "cpu/gpu", "v2", true, - "disabled_by_user"); + /*bridge_type*/ is_supported_by_replicated_brige + ? mlir::TF::kMlirPh1BridgeCounterReplicated + : mlir::TF::kMlirPh1BridgeCounterNonReplicated, + /*bridge_version*/ mlir::TF::kMlirPh1BridgeCounterV2, + /*device_type*/ + is_supported_by_replicated_brige + ? mlir::TF::kMlirPh1BridgeCounterTpu + : mlir::TF::kMlirPh1BridgeCounterNonTpu, + /*fallback_enabled*/ true, + /*result*/ "disabled_by_user"); return MlirOptimizationPassState::Disabled; } case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: // Graph analysis only runs on TPU graph. VLOG(1) << "Skipping MLIR TPU Bridge, disabled because the " "graph has unsupported features. The fallback will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, - "invalid_graph"); + metrics::UpdateTfMlirBridgeFirstPhaseCounter( + /*bridge_type*/ mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version*/ mlir::TF::kMlirPh1BridgeCounterV2, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, + /*fallback_enabled*/ true, + /*result*/ "invalid_graph"); // We set `uses_uninitialized_resource_args` to false here because the // first phase of the bridge is not affected by uninitialized resource // args. @@ -305,16 +318,24 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( VLOG(1) << "Skipping MLIR Replicated Bridge V1 Compat, MLIR Replicated " "bridge disabled " "by user. Fallback will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v1", true, - "disabled_by_user"); + metrics::UpdateTfMlirBridgeFirstPhaseCounter( + /*bridge_type*/ mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version*/ mlir::TF::kMlirPh1BridgeCounterV1, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, + /*fallback_enabled*/ true, + /*result*/ "disabled_by_user"); return MlirOptimizationPassState::Disabled; case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: VLOG(1) << "Skipping MLIR Replicated Bridge V1 Compat, MLIR Replicated " "bridge disabled " "because graph has unsupported features. Old bridge will " "evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v1", true, - "invalid_graph"); + metrics::UpdateTfMlirBridgeFirstPhaseCounter( + /*bridge_type*/ mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version*/ mlir::TF::kMlirPh1BridgeCounterV1, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, + /*fallback_enabled*/ true, + /*result*/ "invalid_graph"); // We set `uses_uninitialized_resource_args` to false here because the // first phase of the bridge is not affected by uninitialized resource // args. diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc index 3c453e88c9dc10..b2e52f6d0dbda5 100644 --- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc @@ -137,12 +137,24 @@ Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) { // Compile the graph to HLO. GraphDebugInfo debug_info; std::vector returns(1); - TF_RETURN_IF_ERROR(BuildHloFromGraph( - *graph, *ctx->builder(), *ctx_res->GetContext(), xla_params, returns, - mlir::SpanToArrayRef(xla_args), control_rets, - device->device_type(), - *ctx->function_library()->GetFunctionLibraryDefinition(), debug_info, - {})); + auto build_hlo = [&](bool unconditionally_use_output_shapes) { + return BuildHloFromGraph( + *graph, *ctx->builder(), *ctx_res->GetContext(), xla_params, returns, + unconditionally_use_output_shapes, + mlir::SpanToArrayRef(xla_args), control_rets, + device->device_type(), + *ctx->function_library()->GetFunctionLibraryDefinition(), debug_info, + {}); + }; + + // Some of the operations that come through here do not know how to set their + // own output shapes (e.g. _XlaHostComputeMlir') so we may need to use the + // unconditional output shapes option. However, many graphs fail if we do it + // unconditionally so try both. + if (!build_hlo(/*unconditionally_use_output_shapes=*/false).ok()) { + // If that failed, then try again with the unconditional set true + TF_RETURN_IF_ERROR(build_hlo(/*unconditionally_use_output_shapes=*/true)); + } // Set context outputs. for (int i = 0, end = returns.size(); i < end; ++i) { diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index ba26c4fe54b31a..6adab4c6c7f6b4 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -1,10 +1,10 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 128d865b8c63cf..50dbf000b03501 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -1,8 +1,8 @@ +load("//tensorflow:tensorflow.default.bzl", "tf_custom_op_py_library") load( "//tensorflow/core/platform:build_config.bzl", "tf_py_clif_cc", ) -load("//tensorflow:tensorflow.default.bzl", "tf_custom_op_py_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 8001c6dc47e18e..8af2c21994d4c4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -752,71 +752,25 @@ Status XlaCompiler::CompileSingleOp( const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::SingleOpCompileArgument& single_op_compile_argument, absl::Span args, XlaCompiler::CompilationResult* result) { - const std::vector& result_dtypes = - single_op_compile_argument.output_dtypes; const NodeDef& node_def = single_op_compile_argument.node_def; TF_ASSIGN_OR_RETURN( auto graph, CreateSingleOpGraph(node_def, args, single_op_compile_argument.output_dtypes)); - auto compile_with_old_bridge = [&]() { - *result = {}; - Status status = ADD_SOURCE_LOCATION(CompileGraph( - compile_options, node_def.name(), std::move(graph), args, result)); - if (status.ok()) { - tensorflow::metrics::IncrementPhase2XlaCompilerCounter( - tensorflow::metrics::Phase2XlaCompilerMetric:: - kCompileSingleOpXlaBuilderSuccess); - } else { - tensorflow::metrics::IncrementPhase2XlaCompilerCounter( - tensorflow::metrics::Phase2XlaCompilerMetric:: - kCompileSingleOpXlaBuilderFailure); - } - return status; - }; - - const ConfigProto* config = &(single_op_compile_argument.config_proto); - auto bridge_rollout = GetMlirBridgeRolloutState( - config ? std::optional(*config) : std::nullopt); - if (bridge_rollout == - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED || - node_def.op() == "VarIsInitializedOp" || - (bridge_rollout != - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED && - options_.device_type.type_string() != DEVICE_TPU_XLA_JIT)) { - return compile_with_old_bridge(); - } - - GraphDebugInfo debug_info; - std::vector control_rets; - if (result_dtypes.empty()) { - control_rets.push_back(node_def.name()); - } - - bool mlir_enabled = (bridge_rollout == - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED); - VLOG(1) << "Attempting MLIR bridge." - << (mlir_enabled ? " MLIR is explicitly enabled." : ""); - auto mlir_result = CompileGraphToXlaHlo( - *graph, mlir::SpanToArrayRef(args), control_rets, - options_.device_type.type_string(), compile_options.use_tuple_arg, - /*analyse_graph=*/!mlir_enabled, *options_.flib_def, debug_info, - options_.shape_determination_fns, result); - - if (mlir_result.ok() || mlir_enabled) { + *result = {}; + Status status = ADD_SOURCE_LOCATION(CompileGraph( + compile_options, node_def.name(), std::move(graph), args, result)); + if (status.ok()) { + tensorflow::metrics::IncrementPhase2XlaCompilerCounter( + tensorflow::metrics::Phase2XlaCompilerMetric:: + kCompileSingleOpXlaBuilderSuccess); + } else { tensorflow::metrics::IncrementPhase2XlaCompilerCounter( tensorflow::metrics::Phase2XlaCompilerMetric:: - kCompileSingleOpMlirSuccess); - return mlir_result; + kCompileSingleOpXlaBuilderFailure); } - tensorflow::metrics::IncrementPhase2XlaCompilerCounter( - tensorflow::metrics::Phase2XlaCompilerMetric:: - kCompileSingleOpMlirFailure); - VLOG(1) << "Failed second phase of the MLIR bridge. Will " - "retry with the old bridge. MLIR bridge compilation status: " - << mlir_result; - return compile_with_old_bridge(); + return status; } Status XlaCompiler::CompileFunction( diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 784c012a0274bc..754d018cc5781c 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" @@ -50,6 +51,18 @@ absl::StatusOr ComputeResultIndex( return result_slice.index(); } +// Returns the number of results. +int CountResults( + absl::Span buffer_infos) { + int num_results = 0; + for (const auto& info : buffer_infos) { + if (info.is_result_parameter()) { + ++num_results; + } + } + return num_results; +} + // Collect names from `entries`, where T is one of // tf2xla::{Feed,Fetch,Variable}. We hold the actual strings in nonempty_names, // and hold arrays of pointers in name_ptrs, terminated by a nullptr entry. @@ -146,6 +159,7 @@ XlaJitCompiledCpuFunction::Compile( xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); TF_ASSIGN_OR_RETURN(size_t result_index, ComputeResultIndex(buffer_assignment)); + const int num_results = CountResults(buffer_infos); std::unique_ptr jit_unique_ptr( new XlaJitCompiledCpuFunction); @@ -173,6 +187,8 @@ XlaJitCompiledCpuFunction::Compile( &jit->static_data_, jit->arg_index_table_.size()); XlaCompiledCpuFunction::set_static_data_num_variables(&jit->static_data_, config.variable_size()); + XlaCompiledCpuFunction::set_static_data_num_results(&jit->static_data_, + num_results); XlaCompiledCpuFunction::set_static_data_result_index(&jit->static_data_, result_index); // Optional metadata is collected and set below. diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 399826ac12ed55..787d67674a2c8e 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -176,6 +176,8 @@ TEST(XlaJitCompiledCpuFunction, Sum) { XlaJitCompiledCpuFunction::Compile(graph_def, config, xla::ExecutableBuildOptions())); XlaCompiledCpuFunction function(jit->StaticData()); + ASSERT_EQ(function.num_args(), 2); + ASSERT_EQ(function.num_results(), 1); // Run the function and check results. *static_cast(function.arg_data(0)) = 10; @@ -258,6 +260,8 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) { XlaJitCompiledCpuFunction::Compile(graph_def, config, xla::ExecutableBuildOptions())); XlaCompiledCpuFunction function(jit->StaticData()); + ASSERT_EQ(function.num_args(), 2); + ASSERT_EQ(function.num_results(), 2); // Run the function and check results. *static_cast(function.arg_data(0)) = 10; diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index de77413927f52e..50a5319db026b2 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -63,6 +63,11 @@ # Placeholder: load py_proto_library load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +load( + "@local_xla//xla/tsl/mkl:build_defs.bzl", + "if_mkl", +) load( "//tensorflow:tensorflow.bzl", "if_android", @@ -83,10 +88,6 @@ load( "transitive_hdrs", ) load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "get_compatible_with_portable", "tensorflow_opensource_extra_deps", "tf_monitoring_framework_deps", "tf_selective_registration_deps") -load( - "@local_xla//xla/tsl/mkl:build_defs.bzl", - "if_mkl", -) # For platform specific build config load( @@ -117,7 +118,6 @@ load( "//tensorflow/core/platform:rules_cc.bzl", "cc_library", ) -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -1035,7 +1035,6 @@ cc_library( "//tensorflow/core:mobile_additional_lib_deps", "//tensorflow/core/platform:resource", "//tensorflow/core/util:stats_calculator_portable", - "@local_xla//xla:bazel_issue_21519", ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(), alwayslink = 1, ) @@ -1717,8 +1716,8 @@ tf_cuda_library( "@local_tsl//tsl/framework:cancellation", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/util:command_line_flags", - "@local_tsl//tsl/util:device_name_utils", + "@local_xla//xla/tsl/util:command_line_flags", + "@local_xla//xla/tsl/util:device_name_utils", ] + if_cuda([ "@local_config_cuda//cuda:cudnn_header", ]) + if_static( diff --git a/tensorflow/core/activity_watcher/BUILD b/tensorflow/core/activity_watcher/BUILD index 0526bccd5a1672..d471f4f892c4fc 100644 --- a/tensorflow/core/activity_watcher/BUILD +++ b/tensorflow/core/activity_watcher/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "if_not_mobile") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 8d96a2cbaff5f8..76b8cc01324619 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -7,19 +7,19 @@ # :java_api_def load( - "//tensorflow:tensorflow.bzl", - "tf_cc_binary", - "tf_cc_test", + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", ) load( "@local_xla//xla/tsl/mkl:build_defs.bzl", "if_mkl", ) -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( - "@local_config_tensorrt//:build_defs.bzl", - "if_tensorrt", + "//tensorflow:tensorflow.bzl", + "tf_cc_binary", + "tf_cc_test", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/core/api_def/base_api/api_def_ComputeDedupDataSizeV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ComputeDedupDataSizeV2.pbtxt new file mode 100644 index 00000000000000..f8066663fa20c3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ComputeDedupDataSizeV2.pbtxt @@ -0,0 +1,40 @@ +op { + graph_op_name: "ComputeDedupDataSizeV2" + visibility: HIDDEN + out_arg { + name: "num_elements" + description: <