diff --git a/.bazelrc b/.bazelrc index d8990ac5c12cc5..d7ae76f096431a 100644 --- a/.bazelrc +++ b/.bazelrc @@ -253,7 +253,7 @@ build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang" build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" +build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host build:nvcc_clang --config=cuda @@ -293,6 +293,11 @@ build:rocm --define=using_rocm_hipcc=true build:rocm --define=tensorflow_mkldnn_contraction_kernel=0 build:rocm --repo_env TF_NEED_ROCM=1 +build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain +build:sycl --define=using_sycl=true +build:sycl --define=tensorflow_mkldnn_contraction_kernel=0 +build:sycl --repo_env TF_NEED_SYCL=1 + # Options to disable default on features build:noaws --define=no_aws_support=true build:nogcp --define=no_gcp_support=true @@ -497,12 +502,12 @@ build:rbe_linux --host_linkopt=-lm build:rbe_linux_cpu --config=rbe_linux # Linux cpu and cuda builds share the same toolchain now. 
-build:rbe_linux_cpu --host_crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" -build:rbe_linux_cpu --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" -build:rbe_linux_cpu --extra_toolchains="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.16-clang_config_platform//:platform" -build:rbe_linux_cpu --host_platform="@sigbuild-r2.16-clang_config_platform//:platform" -build:rbe_linux_cpu --platforms="@sigbuild-r2.16-clang_config_platform//:platform" +build:rbe_linux_cpu --host_crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" +build:rbe_linux_cpu --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" +build:rbe_linux_cpu --extra_toolchains="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.17-clang_config_platform//:platform" +build:rbe_linux_cpu --host_platform="@sigbuild-r2.17-clang_config_platform//:platform" +build:rbe_linux_cpu --platforms="@sigbuild-r2.17-clang_config_platform//:platform" # This is needed for all Clang17 builds but must not be present in GCC builds. build:rbe_linux_cpu --copt=-Wno-error=unused-command-line-argument # This was added in clang-16 by https://reviews.llvm.org/D133574. @@ -511,7 +516,7 @@ build:rbe_linux_cpu --copt=-Wno-error=unused-command-line-argument # See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. build:rbe_linux_cpu --copt=-Wno-gnu-offsetof-extensions # Python config is the same across all containers because the binary is the same -build:rbe_linux_cpu --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.16-clang_config_python" +build:rbe_linux_cpu --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.17-clang_config_python" build:rbe_linux_cpu --python_path="/usr/bin/python3" # These you may need to change for your own GCP project. 
common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instances/default_instance @@ -532,9 +537,9 @@ build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.16-clang_config_cuda" -build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_config_tensorrt" -build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl" +build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.17-clang_config_cuda" +build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.17-clang_config_tensorrt" +build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.17-clang_config_nccl" test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda @@ -639,7 +644,7 @@ test:release_linux_base --test_summary=short # Use the Clang toolchain to compile build:release_cpu_linux --config=release_linux_base -build:release_cpu_linux --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" +build:release_cpu_linux --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. 
diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index d670cd6040401d..94929afefbea9f 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -117,6 +117,18 @@ jobs: map sigbuild-r2.16-clang-python3.10 2.16-python3.10 map sigbuild-r2.16-clang-python3.11 2.16-python3.11 map sigbuild-r2.16-clang-python3.12 2.16-python3.12 + # TF 2.17 + map sigbuild-r2.17 2.17-python3.11 + map sigbuild-r2.17-python3.9 2.17-python3.9 + map sigbuild-r2.17-python3.10 2.17-python3.10 + map sigbuild-r2.17-python3.11 2.17-python3.11 + map sigbuild-r2.17-python3.12 2.17-python3.12 + # TF 2.17 + Clang (containers are the same, but env vars in configs.bzl are different) + map sigbuild-r2.17-clang 2.17-python3.11 + map sigbuild-r2.17-clang-python3.9 2.17-python3.9 + map sigbuild-r2.17-clang-python3.10 2.17-python3.10 + map sigbuild-r2.17-clang-python3.11 2.17-python3.11 + map sigbuild-r2.17-clang-python3.12 2.17-python3.12 - name: Create Pull Request with changes uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3 with: diff --git a/RELEASE.md b/RELEASE.md index 8c9ba51d7993ae..3c6198b60d1918 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -59,6 +59,15 @@ built with support for a given CPU target. This can be useful for skipping target-specific tests if a target is not supported. +* `tf.data` + * Support `data.experimental.distributed_save`. `distributed_save` uses + tf.data service + (https://www.tensorflow.org/api_docs/python/tf/data/experimental/service) + to write distributed dataset snapshots. The call is non-blocking and + returns without waiting for the snapshot to finish. Setting `wait=True` to + `tf.data.Dataset.load` allows the snapshots to be read while they are + being written. + ### Bug Fixes and Other Changes * @@ -79,6 +88,13 @@ `experimental_default_delegate_latest_features` to enable all default delegate features. +* `tf.data` + * Add `wait` to `tf.data.Dataset.load`. 
If `True`, for snapshots written + with `distributed_save`, it reads the snapshot while it is being written. + For snapshots written with regular `save`, it waits for the snapshot until + it's finished. The default is `False` for backward compatibility. Users of + `distributed_save` are recommended to set it to `True`. + ## Thanks to our Contributors This release contains contributions from many people at Google, as well as: diff --git a/WORKSPACE b/WORKSPACE index 675a9481283514..cb024a13a19a47 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,6 +2,8 @@ workspace(name = "org_tensorflow") +# buildifier: disable=load-on-top + # We must initialize hermetic python first. load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") @@ -14,6 +16,12 @@ http_archive( ], ) +http_archive( + name = "rules_java", + sha256 = "c73336802d0b4882e40770666ad055212df4ea62cfa6edf9cb0f9d29828a0934", + url = "https://github.com/bazelbuild/rules_java/releases/download/5.3.5/rules_java-5.3.5.tar.gz", +) + http_archive( name = "rules_python", sha256 = "9d04041ac92a0985e344235f5d946f71ac543f1b1565f2cdbc9a2aaee8adf55b", @@ -21,6 +29,7 @@ http_archive( url = "https://github.com/bazelbuild/rules_python/releases/download/0.26.0/rules_python-0.26.0.tar.gz", ) +# buildifier: disable=same-origin-load load("@rules_python//python:repositories.bzl", "py_repositories") py_repositories() diff --git a/ci/official/requirements_updater/WORKSPACE b/ci/official/requirements_updater/WORKSPACE index f9a116a6a3153e..e29f586f933c6a 100644 --- a/ci/official/requirements_updater/WORKSPACE +++ b/ci/official/requirements_updater/WORKSPACE @@ -2,6 +2,8 @@ workspace(name = "requirements_updater") +# buildifier: disable=load-on-top + load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") http_archive( @@ -20,6 +22,7 @@ http_archive( url = "https://github.com/bazelbuild/rules_python/releases/download/0.26.0/rules_python-0.26.0.tar.gz", ) +# buildifier: disable=same-origin-load 
load("@rules_python//python:repositories.bzl", "py_repositories") py_repositories() diff --git a/ci/official/wheel_test/WORKSPACE b/ci/official/wheel_test/WORKSPACE index d52a3ed895173b..db46144dadbbb1 100644 --- a/ci/official/wheel_test/WORKSPACE +++ b/ci/official/wheel_test/WORKSPACE @@ -2,6 +2,8 @@ workspace(name = "wheel_test") +# buildifier: disable=load-on-top + load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") http_archive( @@ -20,6 +22,7 @@ http_archive( url = "https://github.com/bazelbuild/rules_python/releases/download/0.26.0/rules_python-0.26.0.tar.gz", ) +# buildifier: disable=same-origin-load load("@rules_python//python:repositories.bzl", "py_repositories") py_repositories() diff --git a/configure.py b/configure.py index 66427431b42c16..0081eeabf66bcc 100644 --- a/configure.py +++ b/configure.py @@ -892,8 +892,8 @@ def set_clang_compiler_path_win(environ_cp): ) write_action_env_to_bazelrc('CLANG_COMPILER_PATH', clang_compiler_path) - write_to_bazelrc('build --repo_env=CC=%s' % clang_compiler_path) - write_to_bazelrc('build --repo_env=BAZEL_COMPILER=%s' % clang_compiler_path) + write_to_bazelrc(f'build --repo_env=CC="{clang_compiler_path}"') + write_to_bazelrc(f'build --repo_env=BAZEL_COMPILER="{clang_compiler_path}"') return clang_compiler_path diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index 2335d295d0faf6..05dc3940487eef 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -249,9 +249,9 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.6 \ - --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ - --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 # via requests jax==0.4.7 \ 
--hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index 2335d295d0faf6..05dc3940487eef 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -249,9 +249,9 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.6 \ - --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ - --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index 9bc6eff7313ec3..120ec6ebcd7c72 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -249,9 +249,9 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.6 \ - --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ - --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 9d9e85aceda9c7..36a55514cd788b 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -249,9 +249,9 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.6 \ - --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ - --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f +idna==3.7 \ + 
--hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 # via requests importlib-metadata==7.0.1 \ --hash=sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e \ diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 8c1e9d535d5cb1..71487e2aec0bee 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -1068,6 +1068,7 @@ package_group( "//third_party/py/keras/...", "//third_party/py/tf_keras/...", "//third_party/yggdrasil_decision_forests/...", + "//waymo/accelerator/...", "//waymo/ml/cn/...", "//waymo/ml/models/...", ], @@ -1116,9 +1117,10 @@ bzl_library( "@local_config_cuda//cuda:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl", "@local_config_tensorrt//:build_defs_bzl", - "@local_tsl//tsl:tsl_bzl", "@local_tsl//tsl/platform/default:cuda_build_defs_bzl", + "@local_xla//xla/tsl:tsl_bzl", "@local_xla//xla/tsl/mkl:build_defs_bzl", + "@rules_java//java:rules", ], ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 5490149bc905b1..aa4b5d6987871b 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -176,7 +176,7 @@ void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { } void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); if (s == nullptr) return; delete s->session; delete s; @@ -352,7 +352,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, TF_Status* status) { - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); for (int i = 0; i < noutputs; ++i) { c_outputs[i] = nullptr; } @@ -388,9 +388,9 @@ static Status TF_TensorToTensorV1(const TF_Tensor* src, Tensor* dst) { return InvalidArgument( "Malformed TF_RESOURCE tensor: unable to parse resource handle"); } - return 
::tensorflow::OkStatus(); + return absl::OkStatus(); } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, @@ -959,7 +959,7 @@ void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, TensorShapeProto shape; if (shape.ParseFromArray(proto, static_cast(proto_len))) { desc->node_builder.Attr(attr_name, shape); - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } else { status->status = InvalidArgument("Unparseable TensorShapeProto"); } @@ -986,7 +986,7 @@ void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, } } desc->node_builder.Attr(attr_name, shapes); - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, @@ -999,7 +999,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, TF_Tensor* const* values, int num_values, TF_Status* status) { - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); std::vector t; t.reserve(num_values); @@ -1037,7 +1037,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, desc->node_builder.Attr(attr_name, std::move(attr_value)); } - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, @@ -1552,7 +1552,7 @@ void TF_OperationGetAttrName(TF_Operation* oper, int i, char* output, for (it = attrs.begin(); it != attrs.end(); it++) { if (count == i) { strncpy(output, it->first.c_str(), it->first.length()); - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); return; } count++; @@ -1931,7 +1931,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, for (const auto& pair : results.return_tensors) { return_nodes->emplace_back(pair.first, 
pair.second); } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) { @@ -2063,7 +2063,7 @@ void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status, scope.impl()->control_deps(), &params->cond_output, /* nreturn_nodes */ 1, &cond_output)); *output = cond_output[0]; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); }; // 'body_fn' copies the body graph into the parent graph. @@ -2078,7 +2078,7 @@ void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status, &parent->refiner, params->body_inputs, inputs, scope.impl()->name(), scope.impl()->control_deps(), params->body_outputs, num_loop_vars, outputs)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); }; // Create the while loop using an internal scope. @@ -2312,7 +2312,7 @@ void TF_CloseSession(TF_Session* s, TF_Status* status) { } void TF_DeleteSession(TF_Session* s, TF_Status* status) { - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); if (s == nullptr) return; TF_Graph* const graph = s->graph; if (graph != nullptr) { @@ -2472,7 +2472,7 @@ TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { status->status = InvalidArgument("Unparseable OpList"); return nullptr; } - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); return new TF_ApiDefMap(op_list); } diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 45697e20d1ea05..bedba2c51c6d39 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -501,7 +501,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type, tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({})); std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); return 
tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor)); } diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 2fd92bd7dc0546..25805954eff67c 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -44,7 +44,7 @@ Status ValidateNonRefOutput(const Node* node, int idx) { return IsRefType(dt) ? InvalidArgument("Output ", idx, " of node '", node->name(), "' has a reference type ", DataTypeString(dt)) - : OkStatus(); + : absl::OkStatus(); } // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and @@ -83,7 +83,7 @@ Status ProcessInputs( indices.push_back(idx); } } - return OkStatus(); + return absl::OkStatus(); } // Converts `noutputs` and `outputs` into `outputs_tensors` and does various @@ -105,7 +105,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, fn_name, "'"); output_tensors->emplace_back(node, idx); } - return OkStatus(); + return absl::OkStatus(); } // Populates `body_nodes` with the nodes that will become function's body. 
@@ -142,7 +142,7 @@ Status ComputeBodyNodes( body_nodes->push_back(node); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -294,7 +294,7 @@ int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func, func->record = new tensorflow::FunctionRecord(lib.function(i), {}, false); funcs[i] = func; } - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); return len; } @@ -315,7 +315,7 @@ TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len, TF_Function* func = new TF_Function(); func->record = new tensorflow::FunctionRecord(std::move(fdef), {}, false); - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); return func; } @@ -338,7 +338,7 @@ void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name, (*(fdef_or.value()->mutable_attr()))[string(attr_name)] = attr_value; - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name, diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 877e2f262fba44..14045bbc2daef4 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -249,7 +249,7 @@ void TestEncodeDecode(int line, const std::vector& data) { // Convert back to a C++ Tensor and ensure we get expected output. 
Tensor output; - ASSERT_EQ(OkStatus(), TF_TensorToTensor(dst, &output)) << line; + ASSERT_EQ(absl::OkStatus(), TF_TensorToTensor(dst, &output)) << line; ASSERT_EQ(src.NumElements(), output.NumElements()) << line; for (int64_t i = 0; i < src.NumElements(); ++i) { ASSERT_EQ(data[i], output.flat()(i)) << line; diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3a15bb5ba41f7f..f4b480752c90c9 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -913,8 +913,8 @@ tf_cuda_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/distributed_runtime/coordination:coordination_service_agent", "@local_xla//xla/tsl/c:tsl_status_internal", + "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", ], alwayslink = 1, ) diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index a7eb7798f23dec..05e0cb1c5347df 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "xla/tsl/c/tsl_status_internal.h" +#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" @@ -44,7 +45,6 @@ limitations under the License. 
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/strcat.h" -#include "tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "tsl/framework/cancellation.h" using tensorflow::string; diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 12b8d0c77ea7bb..f7fa3b2491a40b 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -140,8 +140,8 @@ class GradientTape { // Returns whether any tensor in a list of tensors is being watched and has // a trainable dtype. - bool ShouldRecord(gtl::ArraySlice tensor_ids, - gtl::ArraySlice dtypes) const; + bool ShouldRecord(absl::Span tensor_ids, + absl::Span dtypes) const; // Adds this tensor to the list of watched tensors. // @@ -158,8 +158,8 @@ class GradientTape { // nullptr instead of building zeros when build_default_zeros_grads == true. void RecordOperation( const string& op_type, const std::vector& output_tensors, - gtl::ArraySlice input_tensor_id, - gtl::ArraySlice input_dtypes, + absl::Span input_tensor_id, + absl::Span input_dtypes, const std::function& backward_function_getter, const std::function& backward_function_deleter); @@ -174,8 +174,8 @@ class GradientTape { // is set to false. 
Status ComputeGradient( const VSpace& vspace, - const gtl::ArraySlice target_tensor_ids, - const gtl::ArraySlice source_tensor_ids, + const absl::Span target_tensor_ids, + const absl::Span source_tensor_ids, const std::unordered_map& sources_that_are_targets, gtl::ArraySlice output_gradients, absl::Span result, bool build_default_zeros_grads = true); @@ -283,8 +283,8 @@ class ForwardAccumulator { Status Accumulate( const string& op_type, const std::vector& input_tensors, const std::vector& output_tensors, - gtl::ArraySlice input_tensor_id, - gtl::ArraySlice input_dtypes, + absl::Span input_tensor_id, + absl::Span input_dtypes, const ForwardFunction* forward_function, const std::function& backward_function_getter, const std::function& backward_function_deleter); @@ -306,8 +306,8 @@ class ForwardAccumulator { // Indicates whether the forward accumulator should run on an operation with // the specified inputs and dtypes. - bool ShouldRecord(gtl::ArraySlice tensor_ids, - gtl::ArraySlice dtypes); + bool ShouldRecord(absl::Span tensor_ids, + absl::Span dtypes); // Temporarily push or pop transient state for this accumulator. 
// @@ -392,8 +392,8 @@ inline bool IsDtypeTrainable(DataType dtype) { template bool GradientTape::ShouldRecord( - gtl::ArraySlice tensor_ids, - gtl::ArraySlice dtypes) const { + absl::Span tensor_ids, + absl::Span dtypes) const { CHECK_EQ(tensor_ids.size(), dtypes.size()); for (int i = 0; i < tensor_ids.size(); ++i) { if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { @@ -414,8 +414,8 @@ void GradientTape::Watch( template void GradientTape::RecordOperation( const string& op_type, const std::vector& output_tensors, - gtl::ArraySlice input_tensor_id, - gtl::ArraySlice input_dtypes, + absl::Span input_tensor_id, + absl::Span input_dtypes, const std::function& backward_function_getter, const std::function& backward_function_deleter) { if (!ShouldRecord(input_tensor_id, input_dtypes)) { @@ -530,7 +530,7 @@ struct BackpropInitialState { // are needed, are copied and returned in BackpropInitialState. template BackpropInitialState PrepareBackprop( - gtl::ArraySlice target, const TensorTape& tensor_tape, + absl::Span target, const TensorTape& tensor_tape, OpTape* op_tape, const std::unordered_set& sources_set, bool persistent_tape) { std::vector tensor_stack; @@ -605,7 +605,7 @@ std::vector InitialStack( template Status InitialGradients( const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, + absl::Span target_tensor_ids, const std::unordered_map& sources_that_are_targets, gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, @@ -690,8 +690,8 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024; template Status GradientTape::ComputeGradient( const VSpace& vspace, - const gtl::ArraySlice target_tensor_ids, - const gtl::ArraySlice source_tensor_ids, + const absl::Span target_tensor_ids, + const absl::Span source_tensor_ids, const std::unordered_map& sources_that_are_targets, gtl::ArraySlice output_gradients, absl::Span result, bool build_default_zeros_grads) { @@ -907,8 +907,8 @@ Status GradientTape::ComputeGradient( 
template bool ForwardAccumulator::ShouldRecord( - gtl::ArraySlice tensor_ids, - gtl::ArraySlice dtypes) { + absl::Span tensor_ids, + absl::Span dtypes) { if (call_state_.top().backward_tape != nullptr) { // If we're forwarding Accumulate calls to backward_tape's RecordOperation, // we should also delegate ShouldRecord. @@ -1031,8 +1031,8 @@ template Status ForwardAccumulator::Accumulate( const string& op_type, const std::vector& input_tensors, const std::vector& output_tensors, - gtl::ArraySlice input_tensor_id, - gtl::ArraySlice input_dtypes, + absl::Span input_tensor_id, + absl::Span input_dtypes, const ForwardFunction* forward_function, const std::function& backward_function_getter, const std::function& backward_function_deleter) { diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc index d249ed98944758..9d28a6d5cc4714 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.cc +++ b/tensorflow/c/experimental/gradients/nn_grad.cc @@ -26,7 +26,6 @@ limitations under the License. 
using std::vector; using tensorflow::ops::BiasAddGrad; -using tensorflow::ops::Mul; using tensorflow::ops::ReluGrad; namespace tensorflow { diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD index 3d92b7ad3d2992..56586f757f369b 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -31,11 +31,11 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/distributed_runtime/coordination:coordination_service_agent", "@local_xla//xla/pjrt:pjrt_c_api_client", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs", "@local_xla//xla/pjrt/c:pjrt_c_api_helpers", + "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", ], ) diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.cc b/tensorflow/c/experimental/next_pluggable_device/c_api.cc index a2ad1977a7ef3d..15a50a0a7c4060 100644 --- a/tensorflow/c/experimental/next_pluggable_device/c_api.cc +++ b/tensorflow/c/experimental/next_pluggable_device/c_api.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_handle.h" @@ -51,7 +52,6 @@ limitations under the License. 
#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" -#include "tsl/distributed_runtime/coordination/coordination_service_agent.h" TF_Device* TF_GetDevice(TF_OpKernelContext* ctx) { auto* cc_ctx = reinterpret_cast(ctx); diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD index c13bc899f2d016..7b62bd72c56903 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD @@ -36,10 +36,12 @@ tf_cc_tests( ), deps = [ ":renderers", + "//tensorflow/c/experimental/ops/gen/common", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/platform:types", ], ) diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 22cf7275c6efa6..5dcb4a37c7af1d 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -45,6 +45,8 @@ cc_library( "//tensorflow/core/common_runtime/device:device_utils", "//tensorflow/core/platform:strcat", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:status", "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor:platform", @@ -67,7 +69,9 @@ cc_library( "//tensorflow/c:tf_status_helper", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/stream_executor", - "@local_xla//xla/stream_executor:stream_executor_internal", + "@local_xla//xla/stream_executor:event_interface", + "@local_xla//xla/stream_executor:stream_executor_interface", + "@local_xla//xla/stream_executor:stream_interface", ], ) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc 
b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 1efbff9241d732..93d07b431ee4cf 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -21,15 +21,20 @@ limitations under the License. // device. #include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include #include #include #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/c_api_macros_internal.h" #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "xla/stream_executor/executor_cache.h" +#include "xla/stream_executor/host_memory_allocation.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" @@ -61,7 +66,7 @@ absl::Status ValidateSPPlatform(const SP_Platform& platform) { TF_RETURN_IF_ERROR( tensorflow::device_utils::ValidateDeviceType(platform.type)); // `visible_device_count` could be 0 at initialization time. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } absl::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) { @@ -73,33 +78,33 @@ absl::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) { TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_stream_executor); TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_device_fns); TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_device_fns); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } absl::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) { TF_VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats, SP_ALLOCATORSTATS_STRUCT_SIZE); // All other fields could theoretically be zero/null. 
- return ::tensorflow::OkStatus(); + return absl::OkStatus(); } absl::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) { TF_VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem, SP_DEVICE_MEMORY_BASE_STRUCT_SIZE); // All other fields could theoretically be zero/null. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } absl::Status ValidateSPDevice(const SP_Device& device) { TF_VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE); // All other fields could theoretically be zero/null. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } absl::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) { TF_VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE); // All other fields could theoretically be zero/null. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } absl::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se, @@ -135,7 +140,7 @@ absl::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se, TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, mem_zero); TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memset); TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memset32); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } absl::Status ValidateSEPlatformRegistrationParams( @@ -145,7 +150,7 @@ absl::Status ValidateSEPlatformRegistrationParams( TF_VALIDATE_NOT_NULL(SE_PlatformRegistrationParams, params, destroy_platform); TF_VALIDATE_NOT_NULL(SE_PlatformRegistrationParams, params, destroy_platform_fns); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } #undef TF_VALIDATE_NOT_NULL @@ -195,7 +200,7 @@ void HostCallbackTrampoline(void* ctx, TF_Status* status) { delete host_ctx; } -class CStreamExecutor : public internal::StreamExecutorInterface { +class CStreamExecutor : public StreamExecutorInterface { public: explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns, SP_StreamExecutor* stream_executor, @@ -215,9 +220,7 @@ class CStreamExecutor : public 
internal::StreamExecutorInterface { platform_fns_->destroy_device(platform_, &device_); } - absl::Status Init(int device_ordinal) override { - return ::tensorflow::OkStatus(); - } + absl::Status Init() override { return absl::OkStatus(); } DeviceMemoryBase Allocate(uint64 size, int64_t memory_space) override { SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; @@ -237,17 +240,20 @@ class CStreamExecutor : public internal::StreamExecutorInterface { stream_executor_->deallocate(&device_, &device_memory_base); } - void* HostMemoryAllocate(uint64 size) override { - return stream_executor_->host_memory_allocate(&device_, size); + absl::StatusOr> HostMemoryAllocate( + uint64 size) override { + auto* buffer = stream_executor_->host_memory_allocate(&device_, size); + if (buffer == nullptr && size > 0) { + return absl::InternalError( + absl::StrFormat("Failed to allocate HostMemory of size %d", size)); + } + return std::make_unique(buffer, size, this); } void HostMemoryDeallocate(void* mem) override { stream_executor_->host_memory_deallocate(&device_, mem); } - bool HostMemoryRegister(void* mem, uint64 size) override { return false; } - bool HostMemoryUnregister(void* mem) override { return false; } - void* UnifiedMemoryAllocate(uint64 size) override { CHECK(stream_executor_->unified_memory_allocate); return stream_executor_->unified_memory_allocate(&device_, size); @@ -302,11 +308,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface { return tsl::errors::Unimplemented( "SynchronousMemZero is not supported by pluggable device."); } - absl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, - uint64 size) override { - return tsl::errors::Unimplemented( - "SynchronousMemSet is not supported by pluggable device."); - } absl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); @@ -324,16 +325,6 @@ class CStreamExecutor : public 
internal::StreamExecutorInterface { size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - absl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64 size) override { - OwnedTFStatus c_status(TF_NewStatus()); - SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); - SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); - stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst, - &device_mem_src, size, c_status.get()); - return StatusFromTF_Status(c_status.get()); - } absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); @@ -420,7 +411,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { } absl::Status DeallocateEvent(Event* event) override { static_cast(event->implementation())->Destroy(); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } absl::Status RecordEvent(Stream* stream, Event* event) override { SP_Stream stream_handle = @@ -568,14 +559,12 @@ class CStreamExecutor : public internal::StreamExecutorInterface { // Each call creates a new instance of the platform-specific implementation of // the corresponding interface type. 
- std::unique_ptr CreateEventImplementation() - override { - return std::unique_ptr( + std::unique_ptr CreateEventImplementation() override { + return std::unique_ptr( new CEvent(&device_, stream_executor_)); } - std::unique_ptr GetStreamImplementation() - override { - return std::unique_ptr( + std::unique_ptr GetStreamImplementation() override { + return std::unique_ptr( new CStream(&device_, stream_executor_)); } @@ -655,11 +644,10 @@ absl::StatusOr> CPlatform::GetUncachedExecutor( c_status.get()); TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get())); - auto executor = absl::make_unique( + auto executor = std::make_unique( std::move(device), &device_fns_, &stream_executor_, &platform_, &platform_fns_, &timer_fns_, name_, visible_device_count); - auto result = absl::make_unique(this, std::move(executor), - config.ordinal); + auto result = std::make_unique(this, std::move(executor)); return result; } @@ -735,6 +723,6 @@ absl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, stream_executor::PlatformManager::RegisterPlatform(std::move(cplatform))); // TODO(annarev): Return `use_bfc_allocator` value in some way so that it is // available in `PluggableDeviceProcessState` once the latter is checked in. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace stream_executor diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index ad193a045cba50..48ea2ccf26d6f9 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -20,10 +20,12 @@ limitations under the License. 
#include "tensorflow/c/experimental/stream_executor/stream_executor.h" #include "tensorflow/c/tf_status_helper.h" +#include "xla/stream_executor/event_interface.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" +#include "xla/stream_executor/stream_executor_interface.h" +#include "xla/stream_executor/stream_interface.h" #include "tsl/platform/statusor.h" namespace stream_executor { @@ -35,14 +37,15 @@ typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const, // Registers StreamExecutor platform. `device_type` and `platform_name` are // output parameters. -tsl::Status InitStreamExecutorPlugin(void* dso_handle, std::string* device_type, - std::string* platform_name); +absl::Status InitStreamExecutorPlugin(void* dso_handle, + std::string* device_type, + std::string* platform_name); // Allow registering a StreamExecutor plugin using a function (used for // testing). -tsl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, - std::string* device_type, - std::string* platform_name); +absl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, + std::string* device_type, + std::string* platform_name); // This file implements core stream executor base classes in terms of // the C API defined in stream_executor.h. 
A class "CSomething" represents a @@ -72,12 +75,12 @@ class CPlatform : public Platform { } bool UseBfcAllocator() const { return platform_.use_bfc_allocator; } bool ForceMemoryGrowth() const { return platform_.force_memory_growth; } - tsl::StatusOr> DescriptionForDevice( + absl::StatusOr> DescriptionForDevice( int ordinal) const override; - tsl::StatusOr ExecutorForDevice(int ordinal) override; - tsl::StatusOr GetExecutor( + absl::StatusOr ExecutorForDevice(int ordinal) override; + absl::StatusOr GetExecutor( const StreamExecutorConfig& config) override; - tsl::StatusOr> GetUncachedExecutor( + absl::StatusOr> GetUncachedExecutor( const StreamExecutorConfig& config) override; void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); } @@ -95,7 +98,7 @@ class CPlatform : public Platform { stream_executor::ExecutorCache executor_cache_; }; -class CStream : public internal::StreamInterface { +class CStream : public StreamInterface { public: CStream(SP_Device* device, SP_StreamExecutor* stream_executor) : device_(device), @@ -103,10 +106,10 @@ class CStream : public internal::StreamInterface { stream_handle_(nullptr) {} ~CStream() override { Destroy(); } - tsl::Status Create() { + absl::Status Create() { tensorflow::TF_StatusPtr c_status(TF_NewStatus()); stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); - tsl::Status s = tensorflow::StatusFromTF_Status(c_status.get()); + absl::Status s = tensorflow::StatusFromTF_Status(c_status.get()); return s; } @@ -125,7 +128,7 @@ class CStream : public internal::StreamInterface { SP_Stream stream_handle_; }; -class CEvent : public internal::EventInterface { +class CEvent : public EventInterface { public: CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) : device_(device), @@ -133,13 +136,13 @@ class CEvent : public internal::EventInterface { event_handle_(nullptr) {} ~CEvent() override { Destroy(); } - tsl::Status Create() { + absl::Status Create() { tensorflow::TF_StatusPtr 
c_status(TF_NewStatus()); stream_executor_->create_event(device_, &event_handle_, c_status.get()); return tensorflow::StatusFromTF_Status(c_status.get()); } - tsl::Status Record(SP_Stream stream_handle) { + absl::Status Record(SP_Stream stream_handle) { tensorflow::TF_StatusPtr c_status(TF_NewStatus()); stream_executor_->record_event(device_, stream_handle, event_handle_, c_status.get()); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index e4dda6c0a6c177..56f25a5811293e 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -39,17 +39,17 @@ TEST(StreamExecutor, SuccessfulRegistration) { test_util::PopulateDefaultPlatformRegistrationParams(params); }; std::string device_type, platform_name; - tsl::Status status = + absl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); TF_ASSERT_OK(status); - tsl::StatusOr maybe_platform = + absl::StatusOr maybe_platform = PlatformManager::PlatformWithName("MY_DEVICE"); TF_ASSERT_OK(maybe_platform.status()); Platform* platform = std::move(maybe_platform).value(); ASSERT_EQ(platform->Name(), test_util::kDeviceName); ASSERT_EQ(platform->VisibleDeviceCount(), test_util::kDeviceCount); - tsl::StatusOr maybe_executor = + absl::StatusOr maybe_executor = platform->ExecutorForDevice(0); TF_ASSERT_OK(maybe_executor.status()); } @@ -63,7 +63,7 @@ TEST(StreamExecutor, NameNotSet) { }; std::string device_type, platform_name; - tsl::Status status = + absl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ(status.message(), "'name' field in SP_Platform must be set."); @@ -78,7 +78,7 @@ TEST(StreamExecutor, InvalidNameWithSemicolon) { }; std::string device_type, platform_name; - tsl::Status status = + 
absl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); EXPECT_THAT( @@ -95,7 +95,7 @@ TEST(StreamExecutor, InvalidNameWithSlash) { }; std::string device_type, platform_name; - tsl::Status status = + absl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); EXPECT_THAT(status.message(), @@ -111,7 +111,7 @@ TEST(StreamExecutor, CreateDeviceNotSet) { }; std::string device_type, platform_name; - tsl::Status status = + absl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ(status.message(), @@ -127,7 +127,7 @@ TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) { }; std::string device_type, platform_name; - tsl::Status status = + absl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ( @@ -153,7 +153,7 @@ class StreamExecutorTest : public ::testing::Test { platform_, test_util::DestroyPlatform, platform_fns_, test_util::DestroyPlatformFns, device_fns_, se_, timer_fns_); } - tsl::StatusOr maybe_executor = + absl::StatusOr maybe_executor = cplatform_->ExecutorForDevice(ordinal); TF_CHECK_OK(maybe_executor.status()); return std::move(maybe_executor).value(); @@ -185,7 +185,6 @@ TEST_F(StreamExecutorTest, Allocate) { ASSERT_NE(mem.opaque(), nullptr); ASSERT_EQ(mem.size(), 2 * sizeof(int)); executor->Deallocate(&mem); - ASSERT_EQ(mem.opaque(), nullptr); } TEST_F(StreamExecutorTest, HostMemoryAllocate) { @@ -515,25 +514,6 @@ TEST_F(StreamExecutorTest, SyncMemcpyFromHost) { ASSERT_EQ(dst_data, 18); } -TEST_F(StreamExecutorTest, SyncMemcpyDeviceToDevice) { - se_.sync_memcpy_dtod = [](const SP_Device* const device, - SP_DeviceMemoryBase* const device_dst, - const 
SP_DeviceMemoryBase* const device_src, - uint64_t size, TF_Status* const status) { - TF_SetStatus(status, TF_OK, ""); - std::memcpy(device_dst->opaque, device_src->opaque, size); - }; - - StreamExecutor* executor = GetExecutor(0); - size_t size = sizeof(int); - int src_data = 18; - int dst_data = 0; - DeviceMemoryBase device_dst(&dst_data, size); - DeviceMemoryBase device_src(&src_data, size); - ASSERT_TRUE(executor->SynchronousMemcpy(&device_dst, device_src, size)); - ASSERT_EQ(dst_data, 18); -} - TEST_F(StreamExecutorTest, BlockHostForEvent) { static bool block_host_for_event_called = false; se_.create_event = [](const SP_Device* const device, SP_Event* event, @@ -625,7 +605,7 @@ TEST_F(StreamExecutorTest, HostCallbackError) { }; StreamExecutor* executor = GetExecutor(0); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - std::function callback = []() -> tsl::Status { + std::function callback = []() -> absl::Status { return tsl::errors::Unimplemented("Unimplemented"); }; ASSERT_FALSE(stream->DoHostCallbackWithStatus(callback).ok()); diff --git a/tensorflow/c/kernels_experimental.cc b/tensorflow/c/kernels_experimental.cc index 09ce84d42f7392..26173507f29aec 100644 --- a/tensorflow/c/kernels_experimental.cc +++ b/tensorflow/c/kernels_experimental.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/kernels_experimental.h" #include +#include #include #include #include @@ -74,7 +75,7 @@ tensorflow::Status EnsureSparseVariableAccess( tensorflow::Var* var, bool lock_held = false) { auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); if (var->copy_on_read_mode.load()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } std::optional ml; @@ -87,7 +88,7 @@ tensorflow::Status EnsureSparseVariableAccess( // copy-on-read mode is false. 
if (var->tensor()->RefCountIsOne()) { var->copy_on_read_mode.store(true); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } Tensor tmp; if (variantType) { @@ -114,7 +115,7 @@ tensorflow::Status EnsureSparseVariableAccess( } *var->tensor() = tmp; var->copy_on_read_mode.store(true); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } tensorflow::Status PrepareToUpdateVariable( @@ -151,7 +152,7 @@ tensorflow::Status PrepareToUpdateVariable( } *tensor = tmp; } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } tensorflow::mutex* GetTrainingVariableMutex(TF_OpKernelContext* ctx, @@ -186,7 +187,7 @@ void TF_AssignVariable(TF_OpKernelContext* ctx, int input_index, *ptr = new tensorflow::Var(value.dtype()); *(*ptr)->tensor() = value; (*ptr)->is_initialized = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); })); tensorflow::mutex_lock ml(*variable->mu()); @@ -414,9 +415,9 @@ void TF_MaybeLockVariableInputMutexesInOrder( std::sort(acquire_order.begin(), acquire_order.end(), [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); - auto locks = absl::make_unique>(); + auto locks = std::make_unique>(); auto shared_locks = - absl::make_unique>(); + std::make_unique>(); locks->reserve(acquire_order.size()); for (auto acquire : acquire_order) { @@ -565,7 +566,7 @@ static Status ValidateVariantType(const Variant& variant) { type_index_name); } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } static Status VariantBinaryAddFunc( @@ -581,11 +582,11 @@ static Status CCBinaryAddFunc( TF_Tensor* out)) { if (cc_a.dtype() == ::tensorflow::DT_INVALID) { *cc_out = cc_b; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (cc_b.dtype() == ::tensorflow::DT_INVALID) { *cc_out = cc_a; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } Status status; diff --git a/tensorflow/c/tf_buffer.cc b/tensorflow/c/tf_buffer.cc index 864a9e79818db9..a891f89ed16d0c 100644 --- 
a/tensorflow/c/tf_buffer.cc +++ b/tensorflow/c/tf_buffer.cc @@ -78,7 +78,7 @@ Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, out->data = buf; out->length = proto_size; out->data_deallocator = [](void* data, size_t length) { port::Free(data); }; - return OkStatus(); + return absl::OkStatus(); } Status BufferToMessage(const TF_Buffer* in, @@ -87,7 +87,7 @@ Status BufferToMessage(const TF_Buffer* in, return errors::InvalidArgument("Unparseable ", out->GetTypeName(), " proto"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/c/tf_status_helper.cc b/tensorflow/c/tf_status_helper.cc index bb07a9213b4256..bbeae6f76bc497 100644 --- a/tensorflow/c/tf_status_helper.cc +++ b/tensorflow/c/tf_status_helper.cc @@ -22,7 +22,8 @@ limitations under the License. namespace tsl { -void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status) { +void Set_TF_Status_from_Status(TF_Status* tf_status, + const absl::Status& status) { TF_SetStatus(tf_status, TSLCodeFromStatusCode(status.code()), tsl::NullTerminatedMessage(status)); status.ForEachPayload( @@ -33,13 +34,13 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status) { }); } -Status StatusFromTF_Status(const TF_Status* tf_status) { - Status status(StatusCodeFromTSLCode(TF_GetCode(tf_status)), - TF_Message(tf_status)); +absl::Status StatusFromTF_Status(const TF_Status* tf_status) { + absl::Status status(StatusCodeFromTSLCode(TF_GetCode(tf_status)), + TF_Message(tf_status)); TF_ForEachPayload( tf_status, [](const char* key, const char* value, void* capture) { - Status* status = static_cast(capture); + absl::Status* status = static_cast(capture); status->SetPayload(key, absl::Cord(absl::string_view(value))); }, &status); diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h index 0f5c2faa6a0a65..ce833c394cb01b 100644 --- a/tensorflow/c/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -24,10 
+24,11 @@ limitations under the License. namespace tsl { // Set the attribute of "tf_status" from the attributes of "status". -void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status); +void Set_TF_Status_from_Status(TF_Status* tf_status, + const absl::Status& status); // Returns a "status" from "tf_status". -Status StatusFromTF_Status(const TF_Status* tf_status); +absl::Status StatusFromTF_Status(const TF_Status* tf_status); } // namespace tsl namespace tensorflow { diff --git a/tensorflow/c/tf_status_helper_test.cc b/tensorflow/c/tf_status_helper_test.cc index e99c64d68d335d..653395437821c3 100644 --- a/tensorflow/c/tf_status_helper_test.cc +++ b/tensorflow/c/tf_status_helper_test.cc @@ -23,14 +23,14 @@ namespace { TEST(StatusHelper, TestStatusHelper) { TSL_Status* s = TSL_NewStatus(); - Status cc_status(absl::InvalidArgumentError("some error")); + absl::Status cc_status(absl::InvalidArgumentError("some error")); cc_status.SetPayload("key1", absl::Cord("value1")); cc_status.SetPayload("key2", absl::Cord("value2")); Set_TF_Status_from_Status(s, cc_status); ASSERT_EQ(TSL_INVALID_ARGUMENT, TSL_GetCode(s)); ASSERT_EQ(std::string("some error"), TSL_Message(s)); - Status another_cc_status(StatusFromTF_Status(s)); + absl::Status another_cc_status(StatusFromTF_Status(s)); ASSERT_FALSE(another_cc_status.ok()); ASSERT_EQ(std::string("some error"), another_cc_status.message()); ASSERT_EQ(error::INVALID_ARGUMENT, another_cc_status.code()); diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 701c6fe825c36a..96c3fd97344115 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -260,7 +260,7 @@ Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type, Status TensorInterface::FromProto(const tensorflow::TensorProto& from) { bool success = tensor_.FromProto(from); - if (success) return OkStatus(); + if (success) return absl::OkStatus(); return errors::InvalidArgument("Unparseable tensor proto"); } @@ 
-296,7 +296,7 @@ namespace tensorflow { AbstractTensorInterface* TensorInterfaceFromTensor(const Tensor& src, Status* status) { - *status = OkStatus(); + *status = absl::OkStatus(); if (!src.IsInitialized()) { *status = FailedPrecondition( "attempt to use a tensor with an uninitialized value"); @@ -324,7 +324,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) { TF_Tensor* TF_TensorFromTensorShallow(const tensorflow::Tensor& src, Status* status) { - *status = OkStatus(); + *status = absl::OkStatus(); if (!src.IsInitialized()) { *status = FailedPrecondition( "attempt to use a tensor with an uninitialized value"); @@ -343,7 +343,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const { *dst = tensor_; - return OkStatus(); + return absl::OkStatus(); } bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); } diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc index 687c73d8b4e495..1bb04fec430b23 100644 --- a/tensorflow/c/while_loop_test.cc +++ b/tensorflow/c/while_loop_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/platform/logging.h" @@ -45,8 +47,8 @@ class CApiWhileLoopTest : public ::testing::Test { original_graph_description_ = GraphDebugString(); - params_.reset(new TF_WhileParams( - TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_))); + params_ = std::make_unique( + TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_)); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); ASSERT_EQ(original_graph_description_, GraphDebugString()) << "TF_NewWhile() altered graph"; @@ -85,7 +87,7 @@ class CApiWhileLoopTest : public ::testing::Test { ++i; } // TODO(skyewm): use std::make_unique or absl::make_unique when possible. - csession_.reset(new CSession(graph_, s_)); + csession_ = std::make_unique(graph_, s_); csession_->SetInputs(inputs); csession_->SetOutputs(run_outputs); csession_->Run(s_); diff --git a/tensorflow/cc/experimental/base/tests/tensor_test.cc b/tensorflow/cc/experimental/base/tests/tensor_test.cc index 33f9ab637e82a5..95610f098cc470 100644 --- a/tensorflow/cc/experimental/base/tests/tensor_test.cc +++ b/tensorflow/cc/experimental/base/tests/tensor_test.cc @@ -82,7 +82,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) { EXPECT_EQ(tensor.dims(), 1); EXPECT_EQ(tensor.dtype(), dtype); - tensorflow::gtl::ArraySlice tensor_view( + absl::Span tensor_view( reinterpret_cast(tensor.data()), value.size()); EXPECT_EQ(tensor_view[0], 42); EXPECT_EQ(tensor_view[1], 100); @@ -121,7 +121,7 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) { EXPECT_EQ(tensor.dims(), 2); EXPECT_EQ(tensor.dtype(), dtype); - tensorflow::gtl::ArraySlice tensor_view( + absl::Span tensor_view( reinterpret_cast(tensor.data()), value.size()); EXPECT_EQ(tensor_view[0], 42); EXPECT_EQ(tensor_view[1], 100); diff --git 
a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc index cfeaba4e3923ca..77ac7052baa0fe 100644 --- a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc +++ b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc @@ -116,7 +116,7 @@ TYPED_TEST(Construct1DTensorHandleTest, EXPECT_EQ(tensor.dims(), 1); EXPECT_EQ(tensor.dtype(), dtype); - tensorflow::gtl::ArraySlice tensor_view( + absl::Span tensor_view( reinterpret_cast(tensor.data()), value.size()); EXPECT_EQ(tensor_view[0], 42); EXPECT_EQ(tensor_view[1], 100); @@ -166,7 +166,7 @@ TYPED_TEST(Construct2DTensorHandleTest, EXPECT_EQ(tensor.dims(), 2); EXPECT_EQ(tensor.dtype(), dtype); - tensorflow::gtl::ArraySlice tensor_view( + absl::Span tensor_view( reinterpret_cast(tensor.data()), value.size()); EXPECT_EQ(tensor_view[0], 42); EXPECT_EQ(tensor_view[1], 100); diff --git a/tensorflow/cc/experimental/libtf/value.h b/tensorflow/cc/experimental/libtf/value.h index c8347e6c3033d7..61a2888426ee3d 100644 --- a/tensorflow/cc/experimental/libtf/value.h +++ b/tensorflow/cc/experimental/libtf/value.h @@ -56,7 +56,7 @@ using Dict = using DictPtr = std::shared_ptr; using TuplePtr = std::shared_ptr; using Func = - std::function(TaggedValue, TaggedValue)>; + std::function(TaggedValue, TaggedValue)>; // A capsule holds a pointer and a destructor for the pointer (i.e. a generic // shared_ptr to void with a custom deleter). 
using Capsule = std::shared_ptr; diff --git a/tensorflow/cc/framework/grad_op_registry.cc b/tensorflow/cc/framework/grad_op_registry.cc index 268ea764de8a4c..26628759277889 100644 --- a/tensorflow/cc/framework/grad_op_registry.cc +++ b/tensorflow/cc/framework/grad_op_registry.cc @@ -41,7 +41,7 @@ Status GradOpRegistry::Lookup(const string& op, GradFunc* func) const { return errors::NotFound(error_msg); } *func = iter->second; - return OkStatus(); + return absl::OkStatus(); } } // end namespace ops diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index 0c026cf9a0c2c5..90f104bc24b129 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -183,7 +183,7 @@ Status ComputeTheoreticalJacobianTranspose( } } } - return OkStatus(); + return absl::OkStatus(); } Status EvaluateGraph(ClientSession* session, const OutputList& xs, @@ -208,7 +208,7 @@ Status EvaluateGraph(ClientSession* session, const OutputList& xs, } } } - return OkStatus(); + return absl::OkStatus(); } template @@ -272,7 +272,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, } } } - return OkStatus(); + return absl::OkStatus(); } // The Jacobian is always a real-valued matrix. @@ -366,13 +366,13 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, // (Note that std::max may ignore NaN arguments.) 
if (std::isnan(cur_error)) { *max_error = cur_error; - return OkStatus(); + return absl::OkStatus(); } *max_error = std::max(*max_error, cur_error); } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index e0a399d6b1c0da..548f5c04833a2e 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -166,7 +166,7 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad, ready_.push_back(src.node()); } } - return OkStatus(); + return absl::OkStatus(); } std::vector SymbolicGradientBuilder::GetReachableNodes() { @@ -341,7 +341,7 @@ Status SymbolicGradientBuilder::Initialize() { TF_RETURN_IF_ERROR(BackpropAlongEdge(grad_inputs_[i], outputs_[i])); } } - return OkStatus(); + return absl::OkStatus(); } Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) { @@ -372,7 +372,7 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) { *grad = ops::AddN(scope_, grads_to_keep); } - return OkStatus(); + return absl::OkStatus(); } bool SymbolicGradientBuilder::IsPrimitiveOpWithNoGrad(const string& opname) { @@ -388,7 +388,7 @@ Status SymbolicGradientBuilder::CallGradFunction( TF_RETURN_IF_ERROR(registry_->Lookup(op.node()->type_string(), &grad_fn)); TF_RETURN_IF_ERROR(grad_fn(scope_, op, grad_inputs, grad_outputs)); TF_RETURN_IF_ERROR(scope_.status()); - return OkStatus(); + return absl::OkStatus(); } Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, @@ -414,7 +414,8 @@ Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, // Wait until we have all exit nodes' backprops collected before processing // the while loop. // TODO(skyewm): what if not all the exit nodes are reachable? 
- if (backprops.size() < while_ctx->exit_nodes().size()) return OkStatus(); + if (backprops.size() < while_ctx->exit_nodes().size()) + return absl::OkStatus(); // We've seen all the exit nodes for this loop and have collected all the // backprops. Create the gradient graph for the while loop. @@ -435,7 +436,7 @@ Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()})); } } - return OkStatus(); + return absl::OkStatus(); } Status SymbolicGradientBuilder::AddGradients() { @@ -553,7 +554,7 @@ Status SymbolicGradientBuilder::AddGradients() { int num_requested_inputs = p.first->num_outputs() - pending_[p.first->id()]; CHECK_EQ(num_requested_inputs, p.second); } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index ab8b387ab5681a..7bbb3b2bcb5236 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -221,7 +221,7 @@ class Input { tensor_(init.tensor) {} Input(const Tensor& t) // NOLINT(runtime/explicit) - : status_(OkStatus()), tensor_(t) {} + : status_(absl::OkStatus()), tensor_(t) {} Input(const std::initializer_list& init) { // NOLINT(runtime/explicit) @@ -274,8 +274,7 @@ class InputList { const std::initializer_list& inputs) // NOLINT(runtime/explicit) : inputs_(inputs.begin(), inputs.end()) {} - InputList(const tensorflow::gtl::ArraySlice& - inputs) // NOLINT(runtime/explicit) + InputList(const absl::Span& inputs) // NOLINT(runtime/explicit) : inputs_(inputs.begin(), inputs.end()) {} InputList( diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 6667b6919d52e6..0c972612089918 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -311,7 +311,7 @@ Status Scope::ToGraphDef(GraphDef* gdef, bool include_debug_info) const { return *impl()->status_; } graph()->ToGraphDef(gdef, /*include_flib_def=*/true, 
include_debug_info); - return OkStatus(); + return absl::OkStatus(); } Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const { @@ -427,7 +427,7 @@ Scope Scope::WithOpNameImpl(const string& op_name) const { } Scope Scope::WithControlDependencies( - const gtl::ArraySlice control_deps) const { + const absl::Span control_deps) const { return Scope( new Impl(*this, Impl::Tags::ControlDeps(), std::vector(control_deps.begin(), control_deps.end()), @@ -499,7 +499,7 @@ CompositeOpScopes Scope::GetCompositeOpScopes( } Status Scope::DoShapeInference(Node* node) const { - if (impl_->disable_shape_inference_) return OkStatus(); + if (impl_->disable_shape_inference_) return absl::OkStatus(); return impl_->refiner_->AddNode(node); } @@ -547,7 +547,7 @@ Status CreateOutputWithScope(string op_name, scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); TF_RETURN_IF_ERROR(scope.status()); *output = Output(ret, 0); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 771fdaa11688c9..0b0f6871e7f27c 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -125,7 +125,7 @@ class Scope { /// Return a new scope. All ops created within the returned scope will have as /// control dependencies the union of operations in the control_deps vector /// and the control dependencies of the current scope. - Scope WithControlDependencies(gtl::ArraySlice control_deps) const; + Scope WithControlDependencies(absl::Span control_deps) const; /// Same as above, but convenient to add control dependency on the operation /// producing the control_dep output. 
Scope WithControlDependencies(const Output& control_dep) const; diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc index e28306e5a33031..9f966994ea2066 100644 --- a/tensorflow/cc/framework/while_gradients.cc +++ b/tensorflow/cc/framework/while_gradients.cc @@ -70,7 +70,7 @@ Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope, const std::vector& inputs, Output* output) { *output = ToOutput(while_ctx->cond_output()); - return OkStatus(); + return absl::OkStatus(); }; // Body function that adds one to input. @@ -88,7 +88,7 @@ Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope, while_ctx->frame_name(), &outputs, /* create_while_ctx */ false)); *count = outputs[0]; - return OkStatus(); + return absl::OkStatus(); } // Creates a loop that executes `loop_count` times. The returned output is the @@ -126,7 +126,7 @@ Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count, TF_RETURN_IF_ERROR(BuildWhileLoop( scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs, /* create_while_ctx */ false, backprop_execution_pred)); - return OkStatus(); + return absl::OkStatus(); } // Creates the main backprop loop that computes the gradient of the loop @@ -155,7 +155,7 @@ Status AddWhileGradientLoop(WhileContext* while_ctx, const std::vector& inputs, Output* output) { *output = backprop_execution_pred; - return OkStatus(); + return absl::OkStatus(); }; // Body function that builds while body gradient subgraph. 
@@ -173,7 +173,7 @@ Status AddWhileGradientLoop(WhileContext* while_ctx, TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn, frame_name, grad_outputs, /* create_while_ctx */ false)); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc index a54f3cb9f6d010..fdbce41c8ca1ef 100644 --- a/tensorflow/cc/training/coordinator.cc +++ b/tensorflow/cc/training/coordinator.cc @@ -45,7 +45,7 @@ Status Coordinator::RegisterRunner(std::unique_ptr runner) { } mutex_lock l(runners_lock_); runners_.push_back(std::move(runner)); - return OkStatus(); + return absl::OkStatus(); } bool Coordinator::AllRunnersStopped() { @@ -66,7 +66,7 @@ Status Coordinator::RequestStop() { } should_stop_ = true; wait_for_stop_.notify_all(); - return OkStatus(); + return absl::OkStatus(); } bool Coordinator::ShouldStop() { @@ -123,7 +123,7 @@ Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const { return s; } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 3a68e981c24c63..e480ea29a8061b 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -77,7 +77,7 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { thread_pool_.reset(new thread::ThreadPool( Env::Default(), SanitizeThreadSuffix(queue_name_), nthreads)); - return OkStatus(); + return absl::OkStatus(); } QueueRunner::~QueueRunner() { @@ -118,7 +118,7 @@ Status QueueRunner::Start(Session* sess, int wait_for) { return status_; } } - return OkStatus(); + return absl::OkStatus(); } Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms, @@ -212,7 +212,7 @@ Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const { } mutex_lock l(*cg_mu_); cost_graph->MergeFrom(*cost_graph_); - return OkStatus(); + 
return absl::OkStatus(); } void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions& run_options) { diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 44f2622834fa4b..1edcfbf51432b7 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -4,6 +4,7 @@ load( "//tensorflow/core/platform:build_config_root.bzl", "if_llvm_aarch32_available", "if_llvm_aarch64_available", + "if_llvm_hexagon_available", "if_llvm_powerpc_available", "if_llvm_system_z_available", "if_llvm_x86_available", @@ -51,6 +52,8 @@ cc_library( compatible_with = [], defines = if_llvm_aarch32_available(["TF_LLVM_AARCH32_AVAILABLE=1"]) + if_llvm_aarch64_available([ "TF_LLVM_AARCH64_AVAILABLE=1", + ]) + if_llvm_hexagon_available([ + "TF_LLVM_HEXAGON_AVAILABLE=1", ]) + if_llvm_powerpc_available([ "TF_LLVM_POWERPC_AVAILABLE=1", ]) + if_llvm_system_z_available([ @@ -141,6 +144,9 @@ cc_library( ]) + if_llvm_aarch64_available([ "@llvm-project//llvm:AArch64AsmParser", # fixdeps: keep "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep + ]) + if_llvm_hexagon_available([ + "@llvm-project//llvm:HexagonAsmParser", # fixdeps: keep + "@llvm-project//llvm:HexagonCodeGen", # fixdeps: keep ]) + if_llvm_powerpc_available([ "@llvm-project//llvm:PowerPCAsmParser", # fixdeps: keep "@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index d2a4f53a426f09..e558bd67b8ec7d 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -100,7 +100,7 @@ Status CompileXla(xla::CompileOnlyClient* client, compile_result->entry_point = aot_opts.entry_point_name(); compile_result->pointer_size = xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple()); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -196,6 +196,13 @@ static void InitializeTargets() { LLVMInitializeAArch64AsmParser(); LLVMInitializeAArch64AsmPrinter(); #endif +#if 
TF_LLVM_HEXAGON_AVAILABLE + LLVMInitializeHexagonTarget(); + LLVMInitializeHexagonTargetInfo(); + LLVMInitializeHexagonTargetMC(); + LLVMInitializeHexagonAsmParser(); + LLVMInitializeHexagonAsmPrinter(); +#endif #if TF_LLVM_POWERPC_AVAILABLE LLVMInitializePowerPCTarget(); LLVMInitializePowerPCTargetInfo(); @@ -252,7 +259,7 @@ Status Main(const MainFlags& flags) { nodes.insert(fetch.id().node_name()); } std::cout << absl::StrJoin(nodes, ","); - return OkStatus(); + return absl::OkStatus(); } // Read and initialize the graph. @@ -306,7 +313,7 @@ Status Main(const MainFlags& flags) { TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result, metadata_result, &header)); TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tfcompile diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index d531996cbb2be1..fc21fc99a0b84e 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -26,7 +26,7 @@ limitations under the License. namespace tensorflow { namespace tfcompile { -using xla::StatusOr; +using absl::StatusOr; // Represents a set of protocol buffers embedded into an object file and // describes how to access them at runtime. diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index f85fd5fde4c1fa..235e8fda0dfc86 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -287,7 +287,6 @@ void AllocateAndParseFlags() { bool enable_mlir_multiple_local_cpu_devices = false; // Dump graphs in TFG dialect. 
bool use_tfg_graph_dumper = false; - bool enable_mlir_generic_outside_compilation = false; bool enable_tpu_variable_runtime_reformatting_pass = true; flag_list = new std::vector( @@ -391,10 +390,6 @@ void AllocateAndParseFlags() { Flag("tf_dump_graphs_in_tfg", &use_tfg_graph_dumper, "When tf_dump_graphs_in_tfg is true, graphs after transformations " "are dumped in MLIR TFG dialect and not in GraphDef"), - Flag("tf_mlir_enable_generic_outside_compilation", - &enable_mlir_generic_outside_compilation, - "Enables OutsideCompilation passes for MLIR-Based TensorFlow " - "Generic Compiler Bridge."), Flag("tf_mlir_enable_tpu_variable_runtime_reformatting_pass", &enable_tpu_variable_runtime_reformatting_pass, "Enables TPUVariableRuntimeReformatting pass for MLIR-Based " @@ -422,8 +417,6 @@ void AllocateAndParseFlags() { mlir_flags->tf_mlir_enable_composite_tpuexecute_side_effects = enable_mlir_composite_tpuexecute_side_effects; mlir_flags->tf_mlir_enable_strict_clusters = enable_mlir_strict_clusters; - mlir_flags->tf_mlir_enable_generic_outside_compilation = - enable_mlir_generic_outside_compilation; mlir_flags->tf_mlir_enable_tpu_variable_runtime_reformatting_pass = enable_tpu_variable_runtime_reformatting_pass; mlir_flags->tf_mlir_enable_multiple_local_cpu_devices = diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index d2c078a617b258..9dbd6106514ab8 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -290,7 +290,6 @@ struct MlirCommonFlags { bool tf_mlir_enable_convert_control_to_data_outputs_pass; bool tf_mlir_enable_composite_tpuexecute_side_effects; bool tf_mlir_enable_strict_clusters; - bool tf_mlir_enable_generic_outside_compilation; bool tf_mlir_enable_tpu_variable_runtime_reformatting_pass; // TODO(pineapplejuice233): Revisit this flag once the performance impact is verified // with different local CPU devices settings. 
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 25654267a6ae01..9d75388cfbbe80 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -921,13 +921,13 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { absl::StatusOr> execution_inputs; std::map snapshot_ptrs; { - tensorflow::profiler::TraceMe hlo_module_activity( + tsl::profiler::TraceMe hlo_module_activity( [&] { return absl::StrCat( "Populate Inputs (", closure.compilation_result()->xla_input_shapes.size(), ")"); }, - tensorflow::profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); for (const auto& [variable_index, variable_tensor] : closure.resource_var_snapshots()) { @@ -957,11 +957,11 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { closure.executable(), ctx, allocator.get()); OP_REQUIRES(ctx, execution_output.ok(), execution_output.status()); - tensorflow::profiler::TraceMe hlo_module_activity( + tsl::profiler::TraceMe hlo_module_activity( [&] { return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")"); }, - tensorflow::profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); absl::StatusOr> variable_infos = GatherVariableInfo( ctx, *closure.compilation_result(), closure.num_constant_args()); diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 6372b2e5516cd3..37a8bf9ce39df6 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -11,7 +11,10 @@ package( cc_library( name = "xla_ops", srcs = ["xla_ops.cc"], - deps = ["//tensorflow/core:framework"], + deps = [ + "//tensorflow/core:framework", + "@com_google_absl//absl/status", + ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 7c370e46dec63f..8d49471c33741b 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -13,7 +13,7 @@ See 
the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" +#include "absl/status/status.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc index 84a651361681dc..2d982fa82e76b8 100644 --- a/tensorflow/compiler/jit/pjrt_device_context.cc +++ b/tensorflow/compiler/jit/pjrt_device_context.cc @@ -62,7 +62,7 @@ absl::StatusOr> HostTensorToPjRtBuffer( auto first_try_buffer = pjrt_client->BufferFromHostBuffer( cpu_tensor->data(), shape.element_type(), shape.dimensions(), /*byte_strides=*/std::nullopt, - xla::PjRtClient::HostBufferSemantics::kZeroCopy, + xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy, /*on_done_with_host_buffer=*/ [cpu_tensor = *cpu_tensor]() { /* frees tensor */ }, pjrt_device, device_layout); @@ -78,7 +78,7 @@ absl::StatusOr> HostTensorToPjRtBuffer( pjrt_client->BufferFromHostBuffer( cpu_tensor->data(), shape.element_type(), shape.dimensions(), /*byte_strides=*/std::nullopt, - xla::PjRtClient::HostBufferSemantics::kZeroCopy, + xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy, /*on_done_with_host_buffer=*/ [cpu_tensor = *cpu_tensor]() { /* frees tensor */ }, pjrt_device)); return second_try_buffer; @@ -93,7 +93,7 @@ void PjRtDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, Device* device, Tensor* cpu_tensor, StatusCallback done) { - profiler::TraceMe traceme("PjRtDeviceContext::CopyDeviceTensorToCPU"); + tsl::profiler::TraceMe traceme("PjRtDeviceContext::CopyDeviceTensorToCPU"); if (device_tensor->NumElements() == 0) { VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; done(absl::OkStatus()); @@ -136,7 +136,7 @@ void PjRtDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, return; } - 
xla::PjRtFuture future = device_buffer->ToLiteral(literal.get()); + xla::PjRtFuture<> future = device_buffer->ToLiteral(literal.get()); future.OnReady([literal = std::move(literal), done = std::move(done)]( const tensorflow::Status& status) { done(status); }); } @@ -146,14 +146,13 @@ void PjRtDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Tensor* device_tensor, StatusCallback done, bool sync_dst_compute) const { - profiler::TraceMe traceme("PjRtDeviceContext::CopyCPUTensorToDevice"); + tsl::profiler::TraceMe traceme("PjRtDeviceContext::CopyCPUTensorToDevice"); if (cpu_tensor->NumElements() == 0) { VLOG(2) << "CopyCPUTensorToDevice empty tensor"; done(absl::OkStatus()); return; } - // TODO(b/252887149): figure out how to cache PJRT client. absl::StatusOr pjrt_client = GetOrCreatePjRtClient(DeviceType(device->device_type())); if (!pjrt_client.ok()) { @@ -187,8 +186,6 @@ void PjRtDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, CHECK(!result_tensor->GetBuffer()); // Crash OK result_tensor->SetBuffer(std::move(*buffer_or)); } - // TODO(b/244666476): evaluate the performance impact of marking ready when - // the data in device buffer is computed. pjrt_buffer->GetReadyFuture().OnReady(std::move(done)); } @@ -243,7 +240,7 @@ void PjRtDeviceToDeviceCopy(DeviceContext* send_dev_context, AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output, int dev_to_dev_stream_index, StatusCallback done) { - profiler::TraceMe traceme("PjRtDevice_DeviceToDeviceCopy"); + tsl::profiler::TraceMe traceme("PjRtDevice_DeviceToDeviceCopy"); if (input->NumElements() == 0) { VLOG(2) << "PjRtDevice_DeviceToDeviceCopy empty tensor"; done(absl::OkStatus()); @@ -298,8 +295,6 @@ void PjRtDeviceToDeviceCopy(DeviceContext* send_dev_context, CHECK(!output_tensor->GetBuffer()); // Crash OK output_tensor->SetBuffer(std::move(*buffer_or)); } - // TODO(b/244666476): evaluate the performance impact of marking ready when - // the data in device buffer is computed. 
pjrt_buffer->GetReadyFuture().OnReady(std::move(done)); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 2d0d9d51036033..b5b0c16422ccab 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -485,7 +485,8 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, Status XlaDevice::Sync() { VLOG(1) << "XlaDevice::Sync"; - profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMe activity("XlaDevice::Sync", + tsl::profiler::TraceMeLevel::kInfo); std::shared_ptr stream; { mutex_lock lock(mu_); diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 4eba7373910f97..821d294af90f66 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -261,7 +261,7 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, transfer_manager_->TransferLiteralFromDevice( device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal, [this, ref, xla_tensor, done, device_to_host_stream, - device_allows_sync_on_completion](xla::Status status) { + device_allows_sync_on_completion](absl::Status status) { Status done_status = status; VLOG(2) << "Transfer from device as literal: " << xla_tensor->shaped_buffer().ToString(); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index f9657509623cc1..9107e07b83bc21 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -920,7 +920,7 @@ absl::StatusOr>> RunPjRtExecutable( &executable_args, &owned_executable_args, &non_donatable_input_indices)); std::vector> execute_outputs; - std::optional> future; + std::optional> future; if (executable->num_replicas() != 1 || executable->num_partitions() != 1) { TF_ASSIGN_OR_RETURN( execute_outputs, diff --git 
a/tensorflow/compiler/mlir/lite/debug/debug_test.cc b/tensorflow/compiler/mlir/lite/debug/debug_test.cc index 127d485b842f94..d876e05b31fdf6 100644 --- a/tensorflow/compiler/mlir/lite/debug/debug_test.cc +++ b/tensorflow/compiler/mlir/lite/debug/debug_test.cc @@ -120,7 +120,7 @@ class InitPassManagerTest : public testing::Test { builder.create(builder.getUnknownLoc()); } - tsl::Status GetDumpDir(std::string* dump_dir) { + absl::Status GetDumpDir(std::string* dump_dir) { std::vector files; if (auto status = tsl::Env::Default()->GetChildren(path_, &files); !status.ok()) { @@ -131,7 +131,7 @@ class InitPassManagerTest : public testing::Test { "Expecting directory to have one child."); } *dump_dir = tsl::io::JoinPath(path_, files[0]); - return tsl::OkStatus(); + return absl::OkStatus(); } std::string path_; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index b98d3220ee15a8..4655fa7c069b54 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -120,6 +120,7 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/platform/tstring.h" +using absl::StatusOr; using llvm::dyn_cast; using llvm::formatv; using llvm::isa; @@ -143,7 +144,6 @@ using tensorflow::OpOrArgLocNameMapper; using tensorflow::OpOrArgNameMapper; using tensorflow::Status; using tflite::flex::IsAllowlistedFlexOp; -using xla::StatusOr; template using BufferOffset = flatbuffers::Offset; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 90a6e7704d70af..0d477b51b6d467 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -101,6 +101,8 @@ limitations under the License. 
#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" +using absl::Status; +using absl::StatusOr; using llvm::ArrayRef; using mlir::Builder; using mlir::DenseElementsAttr; @@ -115,8 +117,6 @@ using mlir::Value; using mlir::func::FuncOp; using tflite::OperatorT; using tflite::TensorT; -using xla::Status; -using xla::StatusOr; namespace errors = tensorflow::errors; namespace tfl = mlir::TFL; @@ -519,7 +519,7 @@ Status ConvertSubgraphIdxToStablehloRegion( op_state.addAttribute("body", body_attr); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (auto* opts = op.builtin_options_2.AsStablehloReduceWindowOptions()) { int32_t body_idx = opts->body_subgraph_index; @@ -532,7 +532,7 @@ Status ConvertSubgraphIdxToStablehloRegion( op_state.addAttribute("body", body_attr); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (auto* opts = op.builtin_options_2.AsStablehloSortOptions()) { int32_t comparator_idx = opts->comparator_subgraph_index; @@ -545,7 +545,7 @@ Status ConvertSubgraphIdxToStablehloRegion( op_state.addAttribute("comparator", comparator_attr); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (auto* opts = op.builtin_options_2.AsStablehloWhileOptions()) { int32_t body_idx = opts->body_subgraph_index; @@ -566,7 +566,7 @@ Status ConvertSubgraphIdxToStablehloRegion( op_state.addAttribute("body", body_attr); op_state.addAttribute("cond", cond_attr); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (auto* opts = op.builtin_options_2.AsStablehloScatterOptions()) { uint32_t subgraph_idx = opts->update_computation_subgraph_index; @@ -580,10 +580,10 @@ Status ConvertSubgraphIdxToStablehloRegion( op_state.addAttribute(kScatterRegionFuncName, subgraph_attr); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // skip if not supported - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } Status AddOpIntermediatesForLstm( @@ -612,7 +612,7 @@ Status 
AddOpIntermediatesForLstm( op_state.addAttribute(named_attr.getName(), named_attr.getValue()); } } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // TODO(krzysd) Handle function calls @@ -747,7 +747,7 @@ StatusOr ConvertOp( llvm::SmallVector attrs; auto builtin_code = tflite::GetBuiltinCode(&op_code); if (builtin_code == tflite::BuiltinOperator_CUSTOM) { - auto status = ::tensorflow::OkStatus(); + auto status = absl::OkStatus(); std::vector custom_options; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 3166b589418658..f72ef1f9641d48 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -58,9 +58,9 @@ limitations under the License. namespace { +using ::absl::StatusOr; using ::tensorflow::Status; using ::tensorflow::errors::InvalidArgument; -using ::xla::StatusOr; StatusOr GetPaddingAttr(TfLitePadding pad_params, mlir::Builder builder, @@ -448,7 +448,7 @@ Status mlir::CustomOptionsToAttributes( "custom_option", mlir::TFL::ConstBytesAttr::get(builder.getContext(), content))); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // TODO(zichuanwei@): Populate Builtin_options_2 manual for now, should diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 30a19e04fb368a..8ac81939d0d4de 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -578,6 +578,14 @@ inline bool IsBF16ShapedType(Type t) { return false; } +// Returns true if it is a shaped type of FloatType elements. +inline bool IsFloatShapedType(Type t) { + if (auto shaped_type = t.dyn_cast_or_null()) { + return shaped_type.getElementType().isa(); + } + return false; +} + // Returns new shape with rank 'new_dims' with padded ones on the // left if needed. 
inline std::vector GetPaddedShape(ArrayRef old_shape, @@ -3069,6 +3077,50 @@ OpFoldResult SquareOp::fold(FoldAdaptor adaptor) { return ConstFoldUnaryOp(result_type, operands[0], compute); } +//===----------------------------------------------------------------------===// +// MaximumOp +//===----------------------------------------------------------------------===// + +OpFoldResult MaximumOp::fold(FoldAdaptor adaptor) { + auto lhs_type = getLhs().getType(); + auto rhs_type = getRhs().getType(); + // Only constant fold for float tensors of the same type is implemented. + if (lhs_type != rhs_type || !IsFloatShapedType(lhs_type)) return nullptr; + + auto lhs = adaptor.getLhs().dyn_cast_or_null(); + auto rhs = adaptor.getRhs().dyn_cast_or_null(); + if (lhs && lhs.isSplat()) { + APFloat lhs_value = lhs.getSplatValue(); + lhs_value.changeSign(); + if (lhs_value.isLargest()) return getRhs(); + } + if (rhs && rhs.isSplat()) { + APFloat rhs_value = rhs.getSplatValue(); + rhs_value.changeSign(); + if (rhs_value.isLargest()) return getLhs(); + } + return nullptr; +} + +//===----------------------------------------------------------------------===// +// MinimumOp +//===----------------------------------------------------------------------===// + +OpFoldResult MinimumOp::fold(FoldAdaptor adaptor) { + auto lhs_type = getLhs().getType(); + auto rhs_type = getRhs().getType(); + // Only constant fold for float tensors of the same type is implemented. 
+ if (lhs_type != rhs_type || !IsFloatShapedType(lhs_type)) return nullptr; + + auto lhs = adaptor.getLhs().dyn_cast_or_null(); + auto rhs = adaptor.getRhs().dyn_cast_or_null(); + if (lhs && lhs.isSplat() && lhs.getSplatValue().isLargest()) + return getRhs(); + if (rhs && rhs.isSplat() && rhs.getSplatValue().isLargest()) + return getLhs(); + return nullptr; +} + //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 481f5573058b8c..5f4cce6d8e8a76 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -2269,6 +2269,8 @@ def TFL_MaximumOp : TFL_Op<"maximum", [ TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max ); + let hasFolder = 1; + let builders = [TFL_BroadcastableBinaryBuilder]; let hasOptions = 0; @@ -2528,6 +2530,8 @@ def TFL_MinimumOp : TFL_Op<"minimum", [ TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min ); + let hasFolder = 1; + let builders = [TFL_BroadcastableBinaryBuilder]; let hasOptions = 0; diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 3e50192fa0640d..16e12bbb6da04d 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -124,7 +124,7 @@ Status HandleInputOutputArraysWithModule( ") does not exist in the given graph"); } } - return OkStatus(); + return absl::OkStatus(); } Status ConvertSavedModelToTFLiteFlatBuffer( diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index c6928f60f1ccaa..f6d8de698481e3 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ 
b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -112,7 +112,6 @@ cc_library( ":device_target", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", - "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc b/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc index e0035e5c3c5175..7f9b02b9f61473 100644 --- a/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "absl/types/optional.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc index d7bcc43b208675..ecba20595c0f91 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc @@ -18,9 +18,7 @@ limitations under the License. #include #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -28,10 +26,9 @@ limitations under the License. 
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h index c67573d0e9d3f8..8a6bf4f83a28c5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD index f96d4961e733b4..43142a7a7c52dd 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD @@ -20,6 +20,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:weight_only_ptq", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", 
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc index 08f5ecd4851b7e..9561b9003add3b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" @@ -43,6 +44,8 @@ namespace tensorflow { namespace { using ::mlir::quant::stablehlo::StaticRangePtqComponent; +using ::mlir::quant::stablehlo::WeightOnlyPtqComponent; +using ::stablehlo::quantization::Method; using ::stablehlo::quantization::PopulateDefaults; using ::stablehlo::quantization::QuantizationConfig; using ::tensorflow::SignatureDef; @@ -122,22 +125,38 @@ absl::StatusOr RunQuantization( // after variable freezing. 
mlir::PassManager pm(module_op.getContext()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); - mlir::odml::AddLegalizeTFToStablehloPasses( - pm, /*skip_quantization_ops=*/true, - /*skip_resize=*/false, /*skip_stateful_partitioned_call=*/false); + mlir::odml::AddLegalizeTFToStablehloPasses(pm, /*skip_quantization_ops=*/true, + /*skip_resize=*/false, + /*skip_partitioned_calls=*/false); pm.addNestedPass( mlir::quant::stablehlo::createRemoveShardingCustomCallPass()); if (failed(pm.run(module_op))) { return absl::InternalError("Failed to run legalize TF to StableHLO."); } - StaticRangePtqComponent static_range_ptq_component( - module_op.getContext(), quantization_py_function_lib, saved_model_dir, - /*signature_keys=*/exported_names, saved_model_tags, signature_def_map, - GetFunctionAliases(*saved_model_bundle)); + absl::StatusOr quantized_module_op; + // Currently, only StaticRangePtq or WeightOnlyPtq is supported. + // Consider merging the pipelines to address mixed algorithm models. + if (HasQuantizationMethod(updated_config.specs(), + Method::MethodCase::kStaticRangePtq)) { + StaticRangePtqComponent static_range_ptq_component( + module_op.getContext(), quantization_py_function_lib, saved_model_dir, + /*signature_keys=*/exported_names, saved_model_tags, signature_def_map, + GetFunctionAliases(*saved_model_bundle)); + + quantized_module_op = + static_range_ptq_component.Run(module_op, updated_config); + } else if (HasQuantizationMethod(updated_config.specs(), + Method::MethodCase::kWeightOnlyPtq)) { + WeightOnlyPtqComponent weight_only_ptq_component(module_op.getContext()); + quantized_module_op = + weight_only_ptq_component.Run(module_op, updated_config); + } else { + return absl::InvalidArgumentError( + "Quantization config must have either static_range_ptq_preset or " + "weight_only_ptq_preset."); + } - absl::StatusOr quantized_module_op = - static_range_ptq_component.Run(module_op, updated_config); if (!quantized_module_op.ok()) { return 
absl::InternalError("Failed to run quantization. Status msg: " + quantized_module_op.status().ToString()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index bd83f16de105f8..9976d6ff363c8f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -1,4 +1,4 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -693,12 +693,41 @@ cc_library( ], ) +td_library( + name = "composite_td_files", + srcs = [ + "transforms/composite_avg_pool_patterns.td", + "transforms/composite_utils.td", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", + "@llvm-project//mlir:FuncTdFiles", + "@local_xla//xla/mlir_hlo:hlo_ops_td_files", + ], +) + +cc_library( + name = "composite_utils", + srcs = ["transforms/composite_utils.cc"], + hdrs = ["transforms/composite_utils.h"], + deps = [ + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) + cc_library( name = "composite_lowering", srcs = [ + "transforms/composite_avg_pool.cc", "transforms/composite_lowering_pass.cc", ], hdrs = [ + "transforms/composite_avg_pool.h", "transforms/passes.h", ], copts = [ @@ -706,8 +735,13 @@ cc_library( ], deps = [ ":composite_lowering_inc_gen", + ":composite_utils", ":passes_inc_gen", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/core:framework", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", 
"@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -730,6 +764,8 @@ gentbl_cc_library( tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/composite_lowering_patterns.td", deps = [ + ":composite_td_files", + ":composite_utils", "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncTdFiles", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD new file mode 100644 index 00000000000000..c487600517f9b8 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD @@ -0,0 +1,64 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + ":friends", + ], +) + +package_group( + name = "friends", + packages = [ + "//tensorflow/compiler/mlir/lite/...", + ], +) + +tf_cc_binary( + name = "odml-converter", + srcs = ["odml_converter_main.cc"], + visibility = [ + "//tensorflow/compiler/mlir/lite/stablehlo/odml_converter:__subpackages__", + "//third_party/odml/infra:__subpackages__", + ], # Prototype phase. 
+ deps = [ + ":all_passes", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) + +gentbl_cc_library( + name = "passes_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ODMLConverter", + ], + "passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + +cc_library( + name = "all_passes", + hdrs = ["passes.h"], + deps = [":passes_inc_gen"], +) + +exports_files([ + "run_lit.sh", +]) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc new file mode 100644 index 00000000000000..ecd7396c2a4622 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc @@ -0,0 +1,49 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +const char* art = R"( + ___ ___ __ __ _ ___ _ + / _ \| \| \/ | | / __|___ _ ___ _____ _ _| |_ ___ _ _ +| (_) | |) | |\/| | |__ | (__/ _ \ ' \ V / -_) '_| _/ -_) '_| + \___/|___/|_| |_|____| \___\___/_||_\_/\___|_| \__\___|_| +)"; + +int main(int argc, char* argv[]) { + tensorflow::InitMlir y(&argc, &argv); + llvm::errs() << art << "\n"; + + mlir::odml::registerODMLConverterPasses(); + mlir::odml::registerLegalizeStablehloToVhloPass(); + + mlir::DialectRegistry registry; + registry.insert(); + + return failed( + mlir::MlirOptMain(argc, argv, "ODML Converter Driver\n", registry)); +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h new file mode 100644 index 00000000000000..b3589356f196a2 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h @@ -0,0 +1,26 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_ + +namespace mlir::odml { + +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h.inc" + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.td b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.td new file mode 100644 index 00000000000000..800d7e0d2ff59b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.td @@ -0,0 +1,17 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +include "mlir/Pass/PassBase.td" + diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/BUILD new file mode 100644 index 00000000000000..c990b20c8fb51c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/tests/BUILD @@ -0,0 +1,25 @@ +load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/mlir/lite/stablehlo/odml_converter:__subpackages__"], +) + +glob_lit_tests( + name = "filecheck_tests", + data = [":test_utilities"], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = [ + "mlir", + ], +) + +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir/lite/stablehlo/odml_converter:odml-converter", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir index 5924d0dce396c4..c614ee10bf2b45 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir @@ -8,15 +8,15 @@ func.func @hardswish(%arg0: tensor<2xf32>) -> (tensor<*xf32>) { } func.func private @XlaCallModule_aten.hardswish.default.impl_0(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = mhlo.constant dense<6.000000e+00> : tensor - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<2xf32> %2 = mhlo.constant dense<3.40282347E+38> : tensor - %3 = 
"mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<2xf32> %4 = mhlo.constant dense<3.000000e+00> : tensor - %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %5 = "mhlo.broadcast_in_dim"(%4) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<2xf32> %6 = mhlo.constant dense<0.000000e+00> : tensor - %7 = "mhlo.broadcast_in_dim"(%6) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %7 = "mhlo.broadcast_in_dim"(%6) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<2xf32> %8 = mhlo.constant dense<-3.40282347E+38> : tensor - %9 = "mhlo.broadcast_in_dim"(%8) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %9 = "mhlo.broadcast_in_dim"(%8) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<2xf32> %10 = mhlo.add %arg0, %5 : tensor<2xf32> %11 = mhlo.clamp %7, %10, %3 : tensor<2xf32> %12 = mhlo.clamp %9, %11, %1 : tensor<2xf32> @@ -31,4 +31,149 @@ func.func private @XlaCallModule_aten.hardswish.default.impl_0(%arg0: tensor<2xf // CHECK: %[[VAL_2:.*]] = "tf.Identity"(%[[VAL_1]]) {device = ""} : (tensor<2xf32>) -> tensor<*xf32> // CHECK: %[[VAL_3:.*]] = "tf.Identity"(%[[VAL_2]]) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> // CHECK: return %[[VAL_3]] : tensor<*xf32> -// CHECK: } \ No newline at end of file +// CHECK: } + + +func.func @avg_pool2d_1(%arg0: tensor<1x3x6x6xf32>) -> (tensor<*xf32>) { + %0 = mhlo.composite "aten.avg_pool2d.default" %arg0 {composite_attributes = {ceil_mode = false, count_include_pad = true, divisor_override = "py_None", kernel_size = dense<3> : tensor<2xi64>, padding = dense<0> : tensor<2xi64>, stride = dense<1> : tensor<2xi64>}, decomposition = @XlaCallModule_aten.avg_pool2d.default.impl_0} : 
(tensor<1x3x6x6xf32>) -> tensor<1x3x4x4xf32> + %1 = "tf.Identity"(%0) {device = ""} : (tensor<1x3x4x4xf32>) -> tensor<*xf32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} +func.func private @XlaCallModule_aten.avg_pool2d.default.impl_0(%arg0: tensor<1x3x6x6xf32>) -> tensor<1x3x4x4xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor + %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<6x6xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.reduce_window"(%arg0, %2) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %7 : tensor + }) {window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>} : (tensor<1x3x6x6xf32>, tensor) -> tensor<1x3x4x4xf32> + %4 = "mhlo.reduce_window"(%1, %2) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %7 : tensor + }) {window_dimensions = dense<3> : tensor<2xi64>} : (tensor<6x6xf32>, tensor) -> tensor<4x4xf32> + %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<1x3x4x4xf32> + %6 = mhlo.divide %3, %5 : tensor<1x3x4x4xf32> + return %6 : tensor<1x3x4x4xf32> +} + +// CHECK-LABEL: func.func @avg_pool2d_1( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x6x6xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> +// CHECK: %[[VAL_2:.*]] = "tfl.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x3x6x6xf32>, tensor<4xi32>) -> tensor<1x6x6x3xf32> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<4x2xi32> +// CHECK: %[[VAL_4:.*]] = "tfl.pad"(%[[VAL_2]], %[[VAL_3]]) : (tensor<1x6x6x3xf32>, tensor<4x2xi32>) -> tensor<1x6x6x3xf32> +// CHECK: %[[VAL_5:.*]] = "tfl.average_pool_2d"(%[[VAL_4]]) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : 
i32} : (tensor<1x6x6x3xf32>) -> tensor<1x4x4x3xf32> +// CHECK: %[[VAL_6:.*]] = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> +// CHECK: %[[VAL_7:.*]] = "tfl.transpose"(%[[VAL_5]], %[[VAL_6]]) : (tensor<1x4x4x3xf32>, tensor<4xi32>) -> tensor<1x3x4x4xf32> +// CHECK: %[[VAL_8:.*]] = "tf.Identity"(%[[VAL_7]]) {device = ""} : (tensor<1x3x4x4xf32>) -> tensor<*xf32> +// CHECK: %[[VAL_9:.*]] = "tf.Identity"(%[[VAL_8]]) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_9]] : tensor<*xf32> +// CHECK: } + +func.func @avg_pool2d_2(%arg0: tensor<1x3x6x6xf32>) -> (tensor<*xf32>) { + %0 = mhlo.composite "aten.avg_pool2d.default" %arg0 {composite_attributes = {ceil_mode = false, count_include_pad = false, divisor_override = "py_None", kernel_size = dense<3> : tensor<2xi64>, padding = dense<1> : tensor<2xi64>, stride = dense<1> : tensor<2xi64>}, decomposition = @XlaCallModule_aten.avg_pool2d.default.impl_1} : (tensor<1x3x6x6xf32>) -> tensor<1x3x6x6xf32> + %1 = "tf.Identity"(%0) {device = ""} : (tensor<1x3x6x6xf32>) -> tensor<*xf32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} +func.func private @XlaCallModule_aten.avg_pool2d.default.impl_1(%arg0: tensor<1x3x6x6xf32>) -> tensor<1x3x6x6xf32> { + %0 = mhlo.constant dense<[[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00], 
[0.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]]> : tensor<8x8xf32> + %1 = mhlo.constant dense<0.000000e+00> : tensor + %2 = "mhlo.pad"(%arg0, %1) {edge_padding_high = dense<[0, 0, 1, 1]> : tensor<4xi64>, edge_padding_low = dense<[0, 0, 1, 1]> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x3x6x6xf32>, tensor) -> tensor<1x3x8x8xf32> + %3 = "mhlo.reduce_window"(%2, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %7 : tensor + }) {window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>} : (tensor<1x3x8x8xf32>, tensor) -> tensor<1x3x6x6xf32> + %4 = "mhlo.reduce_window"(%0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %7 : tensor + }) {window_dimensions = dense<3> : tensor<2xi64>} : (tensor<8x8xf32>, tensor) -> tensor<6x6xf32> + %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<6x6xf32>) -> tensor<1x3x6x6xf32> + %6 = mhlo.divide %3, %5 : tensor<1x3x6x6xf32> + return %6 : tensor<1x3x6x6xf32> +} + +// CHECK-LABEL: func.func @avg_pool2d_2( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x6x6xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> +// CHECK: %[[VAL_2:.*]] = "tfl.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x3x6x6xf32>, tensor<4xi32>) -> tensor<1x6x6x3xf32> +// CHECK: %[[VAL_3:.*]] = "tfl.average_pool_2d"(%[[VAL_2]]) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x6x6x3xf32>) -> tensor<1x6x6x3xf32> +// CHECK: %[[VAL_4:.*]] = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> +// CHECK: %[[VAL_5:.*]] = "tfl.transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<1x6x6x3xf32>, 
tensor<4xi32>) -> tensor<1x3x6x6xf32> +// CHECK: %[[VAL_6:.*]] = "tf.Identity"(%[[VAL_5]]) {device = ""} : (tensor<1x3x6x6xf32>) -> tensor<*xf32> +// CHECK: %[[VAL_7:.*]] = "tf.Identity"(%[[VAL_6]]) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_7]] : tensor<*xf32> +// CHECK: } + +func.func @upsample_bilinear2d(%arg0: tensor<1x64x16x16xf32>) -> (tensor<1x64x32x32xf32>) { + %0 = mhlo.composite "odml.upsample_bilinear2d" %arg0 {composite_attributes = {align_corners = false, output = dense<32> : tensor<2xi64>}, decomposition = @XlaCallModule_odml.upsample_bilinear2d.impl_21_0} : (tensor<1x64x16x16xf32>) -> tensor<1x64x32x32xf32> + return %0 : tensor<1x64x32x32xf32> +} +func.func private @XlaCallModule_odml.upsample_bilinear2d.impl_21_0(%arg0: tensor<1x64x16x16xf32>) -> tensor<1x64x32x32xf32> { + %0 = mhlo.constant dense<[[0.000000e+00], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01], [7.500000e-01], [2.500000e-01]]> : tensor<32x1xf32> + %1 = mhlo.constant dense<[0.000000e+00, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01, 7.500000e-01, 2.500000e-01]> : tensor<32xf32> + %2 = mhlo.constant dense<[1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 
14, 14, 15, 15, 15]> : tensor<32xi64> + %3 = mhlo.constant dense<16> : tensor + %4 = "mhlo.broadcast_in_dim"(%3) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<32x32xi64> + %5 = mhlo.constant dense<0> : tensor + %6 = "mhlo.broadcast_in_dim"(%5) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<32x32xi64> + %7 = mhlo.constant dense<[0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15]> : tensor<32xi64> + %8 = "mhlo.broadcast_in_dim"(%7) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<32xi64>) -> tensor<32x32xi64> + %9 = mhlo.compare LT, %8, %6 : (tensor<32x32xi64>, tensor<32x32xi64>) -> tensor<32x32xi1> + %10 = mhlo.add %8, %4 : tensor<32x32xi64> + %11 = mhlo.select %9, %10, %8 : tensor<32x32xi1>, tensor<32x32xi64> + %12 = mhlo.reshape %11 : (tensor<32x32xi64>) -> tensor<32x32x1xi64> + %13 = "mhlo.broadcast_in_dim"(%7) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<32xi64>) -> tensor<32x32xi64> + %14 = mhlo.compare LT, %13, %6 : (tensor<32x32xi64>, tensor<32x32xi64>) -> tensor<32x32xi1> + %15 = mhlo.add %13, %4 : tensor<32x32xi64> + %16 = mhlo.select %14, %15, %13 : tensor<32x32xi1>, tensor<32x32xi64> + %17 = mhlo.reshape %16 : (tensor<32x32xi64>) -> tensor<32x32x1xi64> + %18 = "mhlo.concatenate"(%12, %17) <{dimension = 2 : i64}> : (tensor<32x32x1xi64>, tensor<32x32x1xi64>) -> tensor<32x32x2xi64> + %19 = "mhlo.gather"(%arg0, %18) <{dimension_numbers = #mhlo.gather, slice_sizes = dense<[1, 64, 1, 1]> : tensor<4xi64>}> : (tensor<1x64x16x16xf32>, tensor<32x32x2xi64>) -> tensor<1x64x32x32xf32> + %20 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<32xi64>) -> tensor<32x32xi64> + %21 = mhlo.compare LT, %20, %6 : (tensor<32x32xi64>, tensor<32x32xi64>) -> tensor<32x32xi1> + %22 = mhlo.add %20, %4 : tensor<32x32xi64> + %23 = mhlo.select %21, %22, %20 : tensor<32x32xi1>, tensor<32x32xi64> + %24 = 
mhlo.reshape %23 : (tensor<32x32xi64>) -> tensor<32x32x1xi64> + %25 = "mhlo.concatenate"(%12, %24) <{dimension = 2 : i64}> : (tensor<32x32x1xi64>, tensor<32x32x1xi64>) -> tensor<32x32x2xi64> + %26 = "mhlo.gather"(%arg0, %25) <{dimension_numbers = #mhlo.gather, slice_sizes = dense<[1, 64, 1, 1]> : tensor<4xi64>}> : (tensor<1x64x16x16xf32>, tensor<32x32x2xi64>) -> tensor<1x64x32x32xf32> + %27 = mhlo.subtract %26, %19 : tensor<1x64x32x32xf32> + %28 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<32xf32>) -> tensor<1x64x32x32xf32> + %29 = mhlo.multiply %27, %28 : tensor<1x64x32x32xf32> + %30 = mhlo.add %19, %29 : tensor<1x64x32x32xf32> + %31 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<32xi64>) -> tensor<32x32xi64> + %32 = mhlo.compare LT, %31, %6 : (tensor<32x32xi64>, tensor<32x32xi64>) -> tensor<32x32xi1> + %33 = mhlo.add %31, %4 : tensor<32x32xi64> + %34 = mhlo.select %32, %33, %31 : tensor<32x32xi1>, tensor<32x32xi64> + %35 = mhlo.reshape %34 : (tensor<32x32xi64>) -> tensor<32x32x1xi64> + %36 = "mhlo.concatenate"(%35, %17) <{dimension = 2 : i64}> : (tensor<32x32x1xi64>, tensor<32x32x1xi64>) -> tensor<32x32x2xi64> + %37 = "mhlo.gather"(%arg0, %36) <{dimension_numbers = #mhlo.gather, slice_sizes = dense<[1, 64, 1, 1]> : tensor<4xi64>}> : (tensor<1x64x16x16xf32>, tensor<32x32x2xi64>) -> tensor<1x64x32x32xf32> + %38 = "mhlo.concatenate"(%35, %24) <{dimension = 2 : i64}> : (tensor<32x32x1xi64>, tensor<32x32x1xi64>) -> tensor<32x32x2xi64> + %39 = "mhlo.gather"(%arg0, %38) <{dimension_numbers = #mhlo.gather, slice_sizes = dense<[1, 64, 1, 1]> : tensor<4xi64>}> : (tensor<1x64x16x16xf32>, tensor<32x32x2xi64>) -> tensor<1x64x32x32xf32> + %40 = mhlo.subtract %39, %37 : tensor<1x64x32x32xf32> + %41 = mhlo.multiply %40, %28 : tensor<1x64x32x32xf32> + %42 = mhlo.add %37, %41 : tensor<1x64x32x32xf32> + %43 = mhlo.subtract %42, %30 : tensor<1x64x32x32xf32> + %44 = "mhlo.broadcast_in_dim"(%0) 
<{broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>}> : (tensor<32x1xf32>) -> tensor<1x64x32x1xf32> + %45 = mhlo.reshape %44 : (tensor<1x64x32x1xf32>) -> tensor<1x64x32xf32> + %46 = "mhlo.broadcast_in_dim"(%45) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<1x64x32xf32>) -> tensor<1x64x32x32xf32> + %47 = mhlo.multiply %43, %46 : tensor<1x64x32x32xf32> + %48 = mhlo.add %30, %47 : tensor<1x64x32x32xf32> + return %48 : tensor<1x64x32x32xf32> +} + +// CHECK-LABEL: func.func @upsample_bilinear2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x64x16x16xf32>) -> tensor<1x64x32x32xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> +// CHECK: %[[VAL_2:.*]] = "tfl.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x64x16x16xf32>, tensor<4xi32>) -> tensor<1x16x16x64xf32> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<32> : tensor<2xi32> +// CHECK: %[[VAL_4:.*]] = "tfl.resize_bilinear"(%[[VAL_2]], %[[VAL_3]]) {align_corners = false, half_pixel_centers = true} : (tensor<1x16x16x64xf32>, tensor<2xi32>) -> tensor<1x32x32x64xf32> +// CHECK: %[[VAL_5:.*]] = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> +// CHECK: %[[VAL_6:.*]] = "tfl.transpose"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x32x32x64xf32>, tensor<4xi32>) -> tensor<1x64x32x32xf32> +// CHECK: return %[[VAL_6]] : tensor<1x64x32x32xf32> +// CHECK: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast.mlir index 1ec3e3b1fa9fbe..de0cba2f56d258 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast.mlir @@ -5,7 +5,7 @@ func.func @foldBroadcastInDimBeforeMulOp_bcast_dim_1D_float() -> (tensor<1x1x2x4 // CHECK-DAG: %[[RES:.*]] = mhlo.constant dense<{{\[\[\[\[}}1.000000e+00, 4.000000e+00, 9.000000e+00, 1.600000e+01], [5.000000e+00, 1.200000e+01, 2.100000e+01, 3.200000e+01]]]]> : tensor<1x1x2x4xf32> %cst0 = 
mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> %cst1 = mhlo.constant dense<[[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]]> : tensor<1x1x2x4xf32> - %0 = "mhlo.broadcast_in_dim"(%cst0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x1x2x4xf32> + %0 = "mhlo.broadcast_in_dim"(%cst0) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<1x1x2x4xf32> %1 = mhlo.multiply %0, %cst1 : tensor<1x1x2x4xf32> // CHECK: return %[[RES]] : tensor<1x1x2x4xf32> func.return %1 : tensor<1x1x2x4xf32> @@ -16,7 +16,7 @@ func.func @foldBroadcastInDimBeforeMulOp_bcast_dim_2D_float() -> (tensor<1x2x2x3 // CHECK-DAG: %[[RES:.*]] = mhlo.constant dense<{{\[\[\[\[}}1.000000e+00, 4.000000e+00, 9.000000e+00], [4.000000e+00, 1.000000e+01, 1.800000e+01]], {{\[\[}}2.800000e+01, 4.000000e+01, 5.400000e+01], [4.000000e+01, 5.500000e+01, 7.200000e+01]]]]> : tensor<1x2x2x3xf32> %cst0 = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> %cst1 = mhlo.constant dense<[[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]]> : tensor<1x2x2x3xf32> - %0 = "mhlo.broadcast_in_dim"(%cst0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<1x2x2x3xf32> + %0 = "mhlo.broadcast_in_dim"(%cst0) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor<2x3xf32>) -> tensor<1x2x2x3xf32> %1 = mhlo.multiply %0, %cst1 : tensor<1x2x2x3xf32> // CHECK: return %[[RES]] : tensor<1x2x2x3xf32> func.return %1 : tensor<1x2x2x3xf32> @@ -27,7 +27,7 @@ func.func @foldBroadcastInDimBeforeMulOp_bcast_dim_1D_int() -> (tensor<1x1x2x4xi // CHECK-DAG: %[[RES:.*]] = mhlo.constant dense<{{\[\[\[\[}}1, 4, 9, 16], [5, 12, 21, 32]]]]> : tensor<1x1x2x4xi32> %cst0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32> %cst1 = mhlo.constant dense<[[[[1, 2, 3, 4], [5, 6, 7, 8]]]]> : tensor<1x1x2x4xi32> - %0 = "mhlo.broadcast_in_dim"(%cst0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : 
(tensor<4xi32>) -> tensor<1x1x2x4xi32> + %0 = "mhlo.broadcast_in_dim"(%cst0) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<1x1x2x4xi32> %1 = mhlo.multiply %0, %cst1 : tensor<1x1x2x4xi32> // CHECK: return %[[RES]] : tensor<1x1x2x4xi32> func.return %1 : tensor<1x1x2x4xi32> @@ -38,7 +38,7 @@ func.func @foldBroadcastInDimBeforeMulOp_bcast_dim_4D_int() -> tensor<1x2x1x4xi3 // CHECK-DAG: %[[RES:.*]] = mhlo.constant dense<{{\[\[\[\[}}0, 1, 4, 9]], {{\[\[}}0, 1, 4, 9]]]]> : tensor<1x2x1x4xi32> %0 = mhlo.constant dense<[[[[0, 1, 2, 3]]]]> : tensor<1x1x1x4xi32> %1 = mhlo.constant dense<[[[[0, 1, 2, 3]], [[0, 1, 2, 3]]]]> : tensor<1x2x1x4xi32> - %2 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x1x4xi32>) -> tensor<1x2x1x4xi32> + %2 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x4xi32>) -> tensor<1x2x1x4xi32> %3 = mhlo.multiply %1, %2 : tensor<1x2x1x4xi32> // CHECK: return %[[RES]] : tensor<1x2x1x4xi32> return %3 : tensor<1x2x1x4xi32> @@ -48,8 +48,8 @@ func.func @foldBroadcastInDimBeforeMulOp_bcast_dim_4D_int() -> tensor<1x2x1x4xi3 func.func @notFoldBroadcastInDimBeforeMulOpWhenArgIsNonConst_bcast_dim_1D_int(%arg0: tensor<1x1x2x4xi32>) -> (tensor<1x1x2x4xi32>) { // CHECK-DAG: %[[CONST:.*]] = mhlo.constant dense<{{\[}}1, 2, 3, 4]> : tensor<4xi32> %cst0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32> - // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%[[CONST]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x1x2x4xi32> - %0 = "mhlo.broadcast_in_dim"(%cst0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x1x2x4xi32> + // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%[[CONST]]) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<1x1x2x4xi32> + %0 = "mhlo.broadcast_in_dim"(%cst0) <{broadcast_dimensions = dense<3> : 
tensor<1xi64>}> : (tensor<4xi32>) -> tensor<1x1x2x4xi32> // CHECK: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST]], %[[ARG]] : tensor<1x1x2x4xi32> %1 = mhlo.multiply %0, %arg0 : tensor<1x1x2x4xi32> // CHECK: return %[[MUL]] : tensor<1x1x2x4xi32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/fuse_mhlo_convolution.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/fuse_mhlo_convolution.mlir index 042defb58dda00..98e97d1ef7b4c0 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/fuse_mhlo_convolution.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/fuse_mhlo_convolution.mlir @@ -5,13 +5,13 @@ func.func @fuseMulAndConv2D(%input: tensor<1x256x256x3xf32>) -> (tensor<1x256x256x2xf32>) { // CHECK-DAG: %[[FILTER:.+]] = mhlo.constant dense<{{\[\[\[\[}}1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00], [5.000000e+00, 6.000000e+00]]]]> : tensor<1x1x3x2xf32> // CHECK-DAG: %[[CST:.+]] = mhlo.constant dense<[1.000000e-01, 2.000000e-01]> : tensor<2xf32> - // CHECK-DAG: %[[CST_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[CST]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>) -> tensor<1x1x3x2xf32> + // CHECK-DAG: %[[CST_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[CST]]) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<2xf32>) -> tensor<1x1x3x2xf32> // CHECK-DAG: %[[NEW_FILTER:.+]] = mhlo.multiply %[[CST_BCAST]], %[[FILTER]] : tensor<1x1x3x2xf32> // CHECK-DAG: %[[RESULT:.+]] = mhlo.convolution(%[[INPUT]], %[[NEW_FILTER]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x256x256x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x256x256x2xf32> %filter = mhlo.constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]]> : tensor<1x1x3x2xf32> %cst = mhlo.constant dense<[0.1, 0.2]> : tensor<2xf32> %0 = mhlo.convolution(%input, %filter) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], 
window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x256x256x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x256x256x2xf32> - %1 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>) -> tensor<1x256x256x2xf32> + %1 = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<2xf32>) -> tensor<1x256x256x2xf32> %2 = mhlo.multiply %0, %1 : tensor<1x256x256x2xf32> // CHECK-DAG: return %[[RESULT]] func.return %2 : tensor<1x256x256x2xf32> @@ -25,20 +25,20 @@ func.func @fuseMulAndConv2DDynamic(%input: tensor) -> (tensor : tensor<1x1x3x2xf32> // CHECK-DAG: %[[CST_0:.+]] = mhlo.constant dense<[1.000000e-01, 2.000000e-01]> : tensor<2xf32> // CHECK-DAG: %[[CST_1:.+]] = mhlo.constant dense<[3.000000e-01, 4.000000e-01]> : tensor<2xf32> - // CHECK: %[[CST_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[CST_0]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>) -> tensor<1x1x3x2xf32> + // CHECK: %[[CST_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[CST_0]]) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<2xf32>) -> tensor<1x1x3x2xf32> // CHECK: %[[NEW_FILTER:.+]] = mhlo.multiply %[[CST_BCAST]], %[[FILTER]] : tensor<1x1x3x2xf32> // CHECK: %[[CONV:.+]] = mhlo.convolution(%[[INPUT]], %[[NEW_FILTER]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<1x1x3x2xf32>) -> tensor // CHECK: %[[SHAPE:.+]] = shape.shape_of %[[CONV]] : tensor -> tensor<4xindex> - // CHECK: %[[DYNAMIC_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CST_1]], %[[SHAPE]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>, tensor<4xindex>) -> tensor + // CHECK: %[[DYNAMIC_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CST_1]], %[[SHAPE]]) <{broadcast_dimensions 
= dense<3> : tensor<1xi64>}> : (tensor<2xf32>, tensor<4xindex>) -> tensor // CHECK: %[[ADD:.+]] = mhlo.add %[[CONV]], %[[DYNAMIC_BCAST]] : tensor %filter = mhlo.constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]]> : tensor<1x1x3x2xf32> %cst_0 = mhlo.constant dense<[0.1, 0.2]> : tensor<2xf32> %cst_1 = mhlo.constant dense<[0.3, 0.4]> : tensor<2xf32> %0 = mhlo.convolution(%input, %filter) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<1x1x3x2xf32>) -> tensor %1 = shape.shape_of %0 : tensor -> tensor<4xindex> - %2 = "mhlo.dynamic_broadcast_in_dim"(%cst_0, %1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>, tensor<4xindex>) -> tensor + %2 = "mhlo.dynamic_broadcast_in_dim"(%cst_0, %1) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<2xf32>, tensor<4xindex>) -> tensor %3 = mhlo.multiply %0, %2 : tensor - %4 = "mhlo.dynamic_broadcast_in_dim"(%cst_1, %1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>, tensor<4xindex>) -> tensor + %4 = "mhlo.dynamic_broadcast_in_dim"(%cst_1, %1) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<2xf32>, tensor<4xindex>) -> tensor %5 = mhlo.add %3, %4 : tensor // CHECK-DAG: return %[[ADD]] func.return %5 : tensor diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-partitioned-calls.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-partitioned-calls.mlir new file mode 100644 index 00000000000000..79ccccff831531 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-partitioned-calls.mlir @@ -0,0 +1,34 @@ +// RUN: odml-to-stablehlo-opt %s --tf-stablehlo=skip-partitioned-calls=true | FileCheck %s --check-prefix=CHECK-SKIP +// RUN: odml-to-stablehlo-opt %s --tf-stablehlo=skip-partitioned-calls=false | FileCheck %s --check-prefix=CHECK-NOSKIP + +module { + 
func.func @partitioned_call(%arg0: tensor<1x2x2x3xf32>) -> (tensor<1x2x2x3xf32>) { + %0 = "tf.StatefulPartitionedCall"(%arg0) <{ + config = "", config_proto = "", executor_type = "", f = @some_func + }> { + _collective_manager_ids = [], device = "" + } : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xf32> + // CHECK-SKIP: tf.StatefulPartitionedCall + // CHECK-NOSKIP: call @some_func + // CHECK-NOSKIP-NOT: tf.StatefulPartitionedCall + %1 = "tf.PartitionedCall"(%0) <{ + config = "", config_proto = "", executor_type = "", f = @some_other_func + }> { + _collective_manager_ids = [], device = "" + } : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xf32> + // CHECK-SKIP: tf.PartitionedCall + // CHECK-NOSKIP: call @some_other_func + // CHECK-NOSKIP-NOT: tf.PartitionedCall + func.return %1: tensor<1x2x2x3xf32> + } + + // CHECK-SKIP: func.func private @some_func + func.func private @some_func(%arg0: tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xf32> attributes {tf._noinline = true} { + return %arg0 : tensor<1x2x2x3xf32> + } + + // CHECK-SKIP: func.func private @some_other_func + func.func private @some_other_func(%arg0: tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xf32> attributes {tf._noinline = true} { + return %arg0 : tensor<1x2x2x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-stateful-partition-calls.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-stateful-partition-calls.mlir deleted file mode 100644 index 032bf414edfe0f..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-stateful-partition-calls.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s --tf-stablehlo=skip-stateful-partitioned-call=true | FileCheck %s --check-prefix=CHECK-SKIP -// RUN: odml-to-stablehlo-opt %s --tf-stablehlo=skip-stateful-partitioned-call=false | FileCheck %s --check-prefix=CHECK-NOSKIP - -module { - func.func @partitioned_call(%arg0: tensor<1x2x2x3xf32>) -> (tensor<1x2x2x3xf32>) { - %0 = 
"tf.StatefulPartitionedCall"(%arg0) <{ - config = "", config_proto = "", executor_type = "", f = @some_func - }> { - _collective_manager_ids = [], device = "" - } : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xf32> - // CHECK-SKIP: tf.StatefulPartitionedCall - // CHECK-NOSKIP-NOT: tf.StatefulPartitionedCall - func.return %0: tensor<1x2x2x3xf32> - } - - func.func private @some_func(%arg0: tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xf32> attributes {tf._noinline = true} { - return %arg0 : tensor<1x2x2x3xf32> - } -} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir index 41a94b929c0f47..268247e815faa3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir @@ -4,23 +4,10 @@ module { func.func public @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<1x100x32x4xf32>, %arg3: tensor<1x500x4x4xf32>, %arg4: tensor<1x500x4x4xf32>, %arg5: tensor<1x1x100x500xf32>, %arg6: tensor) - -> (tensor<3x3xf32>, tensor<1x100x32x4xf32>) { - // CHECK-ROUNDTRIP: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "odml.update_kv_cache", custom_option = #tfl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK-ROUNDTRIP: %1 = "tfl.custom"(%arg2, %arg3, %arg4, %arg5, %arg6) {custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl} : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> - %0 = func.call @test_kv_cache(%arg0, %arg1) : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - %1 = func.call @test_sdpa(%arg2, %arg3, %arg4, %arg5, %arg6) : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> - return %0, %1 : tensor<3x3xf32>, 
tensor<1x100x32x4xf32> - } - - // CHECK-LABEL: func.func private @test_kv_cache - func.func private @test_kv_cache(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { - // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "odml.update_kv_cache", custom_option = #tfl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - %0 = stablehlo.composite "odml.update_kv_cache" %arg0, %arg1 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - return %0 : tensor<3x3xf32> - } - func.func private @odml.update_kv_cache.impl(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { - // No decomposition provided for test case. - return %arg0 : tensor<3x3xf32> + -> tensor<1x100x32x4xf32> { + // CHECK-ROUNDTRIP: %0 = "tfl.custom"(%arg2, %arg3, %arg4, %arg5, %arg6) {custom_code = "odml.scaled_dot_product_attention", custom_option = #tfl} : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> + %0 = func.call @test_sdpa(%arg2, %arg3, %arg4, %arg5, %arg6) : (tensor<1x100x32x4xf32>, tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<1x1x100x500xf32>, tensor) -> tensor<1x100x32x4xf32> + return %0: tensor<1x100x32x4xf32> } // CHECK-LABEL: func.func private @test_sdpa @@ -34,4 +21,30 @@ module { return %arg0 : tensor<1x100x32x4xf32> } + // CHECK-LABEL: func.func private @test_multiple_kv_caches + func.func private @test_multiple_kv_caches(%arg0: tensor<1x500x4x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<100xi64>, %arg3: tensor<1x100x4x4xf32>, %arg4: tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) { + // CHECK: %0:2 = "tfl.custom"(%arg2, %arg3, %arg4) {custom_code = "odml.update_kv_cache", custom_option = #tfl} : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + // CHECK: %1:2 = 
"tfl.custom"(%arg2, %arg3, %arg4) {custom_code = "odml.update_kv_cache", custom_option = #tfl} : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + %0:2 = stablehlo.composite "odml.update_kv_cache" %arg0, %arg1, %arg2, %arg3, %arg4 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl_0} : (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + %1:2 = stablehlo.composite "odml.update_kv_cache" %0#0, %0#1, %arg2, %arg3, %arg4 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl_0} : (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + return %1#0, %1#1 : tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32> + } + func.func private @odml.update_kv_cache.impl_0(%arg0: tensor<1x500x4x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<100xi64>, %arg3: tensor<1x100x4x4xf32>, %arg4: tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) { + %0 = stablehlo.constant dense<500> : tensor<100xi64> + %1 = stablehlo.constant dense<0> : tensor<100xi64> + %2 = stablehlo.compare LT, %arg2, %1 : (tensor<100xi64>, tensor<100xi64>) -> tensor<100xi1> + %3 = stablehlo.add %arg2, %0 : tensor<100xi64> + %4 = stablehlo.select %2, %3, %arg2 : tensor<100xi1>, tensor<100xi64> + %5 = stablehlo.reshape %4 : (tensor<100xi64>) -> tensor<100x1xi64> + %6 = "stablehlo.scatter"(%arg0, %5, %arg3) ({ + ^bb0(%arg5: tensor, %arg6: tensor): + stablehlo.return %arg6 : tensor + }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x500x4x4xf32>, tensor<100x1xi64>, tensor<1x100x4x4xf32>) -> tensor<1x500x4x4xf32> + %7 = "stablehlo.scatter"(%arg1, %5, %arg4) ({ + ^bb0(%arg5: tensor, 
%arg6: tensor): + stablehlo.return %arg6 : tensor + }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x500x4x4xf32>, tensor<100x1xi64>, tensor<1x100x4x4xf32>) -> tensor<1x500x4x4xf32> + return %6, %7 : tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32> + } + } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-constant.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-constant.mlir index d4d3b0abf01de3..f5d5a734fad666 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-constant.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-constant.mlir @@ -9,7 +9,7 @@ module { // CHECK: module { // CHECK-NEXT: func @main() -> tensor<1xi64> { -// CHECK-NEXT: %0 = stablehlo.constant dense<2> : tensor<1xi64> -// CHECK-NEXT: return %0 : tensor<1xi64> +// CHECK-NEXT: %[[c0:.+]] = stablehlo.constant dense<2> : tensor<1xi64> +// CHECK-NEXT: return %[[c0]] : tensor<1xi64> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir index 8b454fa898e5a6..2329f68b36fc33 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir @@ -55,7 +55,7 @@ func.func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32> // CHECK: } func.func @broadcast_add(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1000xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32> %1 = 
mhlo.add %0, %arg1 : tensor<1x1000xf32> %2 = mhlo.add %arg1, %0 : tensor<1x1000xf32> func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32> @@ -109,7 +109,7 @@ func.func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32> // CHECK: } func.func @broadcast_div(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1000xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32> %1 = mhlo.divide %0, %arg1 : tensor<1x1000xf32> %2 = mhlo.divide %arg1, %0 : tensor<1x1000xf32> func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32> @@ -146,7 +146,7 @@ func.func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi3 // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<4xi32>, tensor<4xi32> // CHECK: } func.func @broadcast_shift_left(%arg0: tensor<1xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<4xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0]> : tensor<1xi64>}> : (tensor<1xi32>) -> tensor<4xi32> %1 = mhlo.shift_left %0, %arg1 : tensor<4xi32> %2 = mhlo.shift_left %arg1, %0 : tensor<4xi32> func.return %1, %2 : tensor<4xi32>, tensor<4xi32> @@ -183,7 +183,7 @@ func.func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32> // CHECK: } func.func @broadcast_maximum(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> 
tensor<1x1000xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32> %1 = mhlo.maximum %0, %arg1 : tensor<1x1000xf32> %2 = mhlo.maximum %arg1, %0 : tensor<1x1000xf32> func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32> @@ -209,7 +209,7 @@ func.func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32> // CHECK: } func.func @broadcast_minimum(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1000xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32> %1 = mhlo.minimum %0, %arg1 : tensor<1x1000xf32> %2 = mhlo.minimum %arg1, %0 : tensor<1x1000xf32> func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32> @@ -234,7 +234,7 @@ func.func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32> // CHECK: } func.func @broadcast_mul(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1000xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32> %1 = mhlo.multiply %0, %arg1 : tensor<1x1000xf32> %2 = mhlo.multiply %arg1, %0 : tensor<1x1000xf32> func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32> @@ -291,7 +291,7 @@ func.func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32> // CHECK: } func.func @broadcast_sub(%arg0: 
tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1000xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32> %1 = mhlo.subtract %0, %arg1 : tensor<1x1000xf32> %2 = mhlo.subtract %arg1, %0 : tensor<1x1000xf32> func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32> @@ -317,7 +317,7 @@ func.func @broadcast_sub_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> t // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<4xf32>, tensor<4xf32> // CHECK: } func.func @broadcast_atan2(%arg0: tensor<1xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<4xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0]> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<4xf32> %1 = mhlo.atan2 %0, %arg1 : tensor<4xf32> %2 = mhlo.atan2 %arg1, %0 : tensor<4xf32> func.return %1, %2 : tensor<4xf32>, tensor<4xf32> @@ -429,7 +429,7 @@ func.func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi3 // CHECK: return %[[VAL_2]] : tensor<1x4xi8> // CHECK: } func.func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1xi8>) -> tensor<1x4xi8> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>}> : (tensor<1xi8>) -> tensor<1x4xi8> %1 = mhlo.or %0, %arg1 : tensor<1x4xi8> func.return %1 : tensor<1x4xi8> } @@ -474,7 +474,7 @@ func.func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi // CHECK: return %[[VAL_2]] : tensor<1x4xi8> // CHECK: } func.func @bitwise_xor_broadcast(%arg0: tensor<1xi8>, %arg1: 
tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>) -> tensor<1x4xi8> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<1xi8>) -> tensor<1x4xi8> %1 = mhlo.xor %0, %arg1 : tensor<1x4xi8> func.return %1 : tensor<1x4xi8> } @@ -497,7 +497,7 @@ func.func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi // CHECK: return %[[VAL_2]] : tensor<1x4xi8> // CHECK: } func.func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1xi8>) -> tensor<1x4xi8> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>}> : (tensor<1xi8>) -> tensor<1x4xi8> %1 = mhlo.and %0, %arg1 : tensor<1x4xi8> func.return %1 : tensor<1x4xi8> } @@ -543,7 +543,7 @@ func.func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<4xi32>, tensor<4xi32> // CHECK: } func.func @broadcast_pow(%arg0: tensor<1xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<4xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0]> : tensor<1xi64>}> : (tensor<1xi32>) -> tensor<4xi32> %1 = mhlo.power %0, %arg1 : tensor<4xi32> %2 = mhlo.power %arg1, %0 : tensor<4xi32> func.return %1, %2 : tensor<4xi32>, tensor<4xi32> @@ -695,7 +695,7 @@ func.func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor // CHECK: } func.func @equal_broadcast(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xi32>) -> tensor<1x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> 
: tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32> %1 = "mhlo.compare"(%0, %arg1) {comparison_direction = #mhlo} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %1 : tensor<1x2xi1> } @@ -758,7 +758,7 @@ func.func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @notequal_broadcast(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xi32>) -> tensor<1x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32> %1 = "mhlo.compare"(%0, %arg1) {comparison_direction = #mhlo} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %1 : tensor<1x2xi1> } @@ -814,7 +814,7 @@ func.func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_greater(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xi32>) -> tensor<1x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32> %1 = "mhlo.compare"(%0, %arg1) {comparison_direction = #mhlo} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %1 : tensor<1x2xi1> } @@ -855,7 +855,7 @@ func.func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2 // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_greater_equal(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xi32>) -> tensor<1x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) 
<{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32> %1 = "mhlo.compare"(%0, %arg1) {comparison_direction = #mhlo} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %1 : tensor<1x2xi1> } @@ -889,7 +889,7 @@ func.func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_less(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xi32>) -> tensor<1x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32> %1 = "mhlo.compare"(%0, %arg1) {comparison_direction = #mhlo} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %1 : tensor<1x2xi1> } @@ -923,7 +923,7 @@ func.func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1 // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_less_equal(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xi32>) -> tensor<1x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32> %1 = "mhlo.compare"(%0, %arg1) {comparison_direction = #mhlo} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %1 : tensor<1x2xi1> } @@ -947,7 +947,7 @@ func.func @broadcast_less_equal_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32 // CHECK: return %[[VAL_3]] : tensor<6x3xf32> // CHECK: } func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + %2 = "mhlo.concatenate"(%arg0, %arg1) 
<{dimension = 0 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> func.return %2 : tensor<6x3xf32> } @@ -959,7 +959,7 @@ func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6 // CHECK: return %[[VAL_3]] : tensor<3x6xf32> // CHECK: } func.func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { - %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> + %2 = "mhlo.concatenate"(%arg0, %arg1) <{dimension = 1 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> func.return %2 : tensor<3x6xf32> } @@ -1113,7 +1113,7 @@ func.func @selectv2_pred_scalar(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: // CHECK: return %[[VAL_3]] : tensor<1x100xi32> // CHECK: } func.func @selectv2_broadcasted_operand(%arg0: tensor, %arg1: tensor<1x1xi32>, %arg2: tensor<1x100xi32>) -> tensor<1x100xi32> { - %0 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xi32>) -> tensor<1x100xi32> + %0 = "mhlo.broadcast_in_dim"(%arg1) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x100xi32> %1 = "mhlo.select"(%arg0, %0, %arg2) : (tensor, tensor<1x100xi32>, tensor<1x100xi32>) -> tensor<1x100xi32> func.return %1 : tensor<1x100xi32> } @@ -1126,7 +1126,7 @@ func.func @selectv2_broadcasted_operand(%arg0: tensor, %arg1: tensor<1x1xi32 // CHECK: return %[[VAL_3]] : tensor<1x100xi32> // CHECK: } func.func @selectv2_broadcasted_condition(%arg0: tensor<1x1xi1>, %arg1: tensor<1x100xi32>, %arg2: tensor<1x100xi32>) -> tensor<1x100xi32> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xi1>) -> tensor<1x100xi1> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi1>) -> tensor<1x100xi1> %1 = "mhlo.select"(%0, %arg1, %arg2) : (tensor<1x100xi1>, tensor<1x100xi32>, 
tensor<1x100xi32>) -> tensor<1x100xi32> func.return %1 : tensor<1x100xi32> } @@ -1142,7 +1142,7 @@ func.func @selectv2_broadcasted_condition(%arg0: tensor<1x1xi1>, %arg1: tensor<1 func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> - %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> + %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<2x3xf32>) -> tensor<3x2xf32> func.return %2 : tensor<3x2xf32> } @@ -1157,7 +1157,7 @@ func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { func.func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi32> %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> - %2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> + %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> func.return %2 : tensor<3x2x1xf32> } @@ -1172,7 +1172,7 @@ func.func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64> - %2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> + %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> func.return %2 : tensor<3x2x1xf32> } @@ -1187,7 +1187,7 @@ func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { func.func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64> 
- %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor) -> tensor<4x?xf32> + %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor) -> tensor<4x?xf32> func.return %2 : tensor<4x?xf32> } @@ -1588,7 +1588,7 @@ func.func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { // CHECK: return %[[VAL_4]] : tensor<1x519xf32> // CHECK: } func.func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> { - %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32> + %0 = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<1x4672xf32>) -> tensor<1x519xf32> func.return %0 : tensor<1x519xf32> } @@ -1676,7 +1676,7 @@ func.func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) // CHECK: return %[[VAL_2]] : tensor<3x8x8x16xf32> // CHECK: } func.func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"}> : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> func.return %0 : tensor<3x8x8x16xf32> } @@ -1689,7 +1689,7 @@ func.func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x // CHECK: return %[[VAL_4]] : tensor<3x8x8x16xf32> // CHECK: } func.func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> + %0 = 
"mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"}> : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> func.return %0 : tensor<3x8x8x16xf32> } @@ -1699,7 +1699,7 @@ func.func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x // CHECK %[[VAL_0:.*]] = "tf.BroadcastTo"(%[[ARG_0]], %[[ARG_1]]) : (tensor, tensor<5xi32>) -> tensor // CHECK return %[[VAL_0]] : tensor func.func @dynamic_broadcast_in_dim_tf_style(%arg0: tensor, %arg1: tensor<5xi32>) -> tensor { - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>} : (tensor, tensor<5xi32>) -> tensor + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{broadcast_dimensions = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>}> : (tensor, tensor<5xi32>) -> tensor func.return %0 : tensor } @@ -1713,7 +1713,7 @@ func.func @dynamic_broadcast_in_dim_tf_style(%arg0: tensor, %arg1 // CHECK %[[VAL_2:.*]] = "tf.BroadcastTo"(%[[VAL_1]], %[[ARG_1]]) : (tensor, tensor<4xi32>) -> tensor // CHECK return %[[VAL_2]] : tensor func.func @dynamic_broadcast_in_dim_general_case_expand_back_dims(%arg0: tensor, %arg1: tensor<4xi32>) -> tensor { - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<4xi32>) -> tensor + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor, tensor<4xi32>) -> tensor func.return %0 : tensor } @@ -1725,7 +1725,7 @@ func.func @dynamic_broadcast_in_dim_general_case_expand_back_dims(%arg0: tensor< // CHECK %[[VAL_1:.*]] = "tf.BroadcastTo"(%[[VAL_0]], %[[ARG_1]]) : (tensor, tensor<4xi32>) -> tensor // CHECK return %[[VAL_1]] : tensor func.func @dynamic_broadcast_in_dim_general_case_expand_middle_dim(%arg0: tensor, %arg1: tensor<4xi32>) -> tensor { - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1, 3]> : tensor<3xi64>} : 
(tensor, tensor<4xi32>) -> tensor + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{broadcast_dimensions = dense<[0, 1, 3]> : tensor<3xi64>}> : (tensor, tensor<4xi32>) -> tensor func.return %0 : tensor } @@ -3041,7 +3041,7 @@ func.func @convert_reduce_to_min_int(%arg0: tensor<1x4xi32>) -> tensor<1xi32> { // CHECK: return %[[VAL_3]] : tensor<123xf32> // CHECK: } func.func @convert_iota_1d() -> tensor<123xf32> { - %0 = "mhlo.iota"() { iota_dimension = 0 : i64 } : () -> tensor<123xf32> + %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<123xf32> func.return %0 : tensor<123xf32> } @@ -3057,7 +3057,7 @@ func.func @convert_iota_1d() -> tensor<123xf32> { // CHECK: return %[[VAL_7]] : tensor<5x7x9xi32> // CHECK: } func.func @convert_iota_3d() -> tensor<5x7x9xi32> { - %0 = "mhlo.iota"() { iota_dimension = 1 : i64 } : () -> tensor<5x7x9xi32> + %0 = "mhlo.iota"() <{ iota_dimension = 1 : i64 }> : () -> tensor<5x7x9xi32> func.return %0 : tensor<5x7x9xi32> } @@ -3091,7 +3091,7 @@ func.func @convert_avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8 func.func @convert_avgpool_valid_broadcasted_divisor(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { %0 = mhlo.constant dense<0.0> : tensor %1 = mhlo.constant dense<9.0> : tensor - %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<4x7x7x8xf32> + %2 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x7x7x8xf32> %3 = "mhlo.reduce_window"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor): %5 = mhlo.add %arg1, %arg2 : tensor @@ -3167,7 +3167,7 @@ func.func @convert_avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x // CHECK: } func.func @convert_avgpool_valid_rw_broadcasted_const_lhs(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { %0 = mhlo.constant dense<1.0> : tensor - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> 
tensor<4x16x16x8xf32> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x16x16x8xf32> %2 = mhlo.constant dense<0.0> : tensor %3 = "mhlo.reduce_window"(%arg0, %2) ({ ^bb0(%arg1: tensor, %arg2: tensor): @@ -3288,7 +3288,7 @@ func.func @convert_avgpool_reshape_broadcast(%arg0: tensor<4x16x16x8xf32>) -> te mhlo.return %7 : tensor }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x16x16x1xf32>, tensor) -> tensor<1x8x8x1xf32> %4 = mhlo.reshape %3 : (tensor<1x8x8x1xf32>) -> tensor<8x8xf32> - %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x8xf32>) -> tensor<4x8x8x8xf32> + %5 = "mhlo.broadcast_in_dim"(%4) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<8x8xf32>) -> tensor<4x8x8x8xf32> %6 = mhlo.divide %2, %5 : tensor<4x8x8x8xf32> return %6 : tensor<4x8x8x8xf32> } @@ -3626,7 +3626,7 @@ func.func @convert_floor_div_broadcast_cst(%arg0: tensor<10x8xf32>) -> tensor<10 %1 = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00, 1.600000e+01, 3.200000e+01, 6.400000e+01, 1.280000e+02]> : tensor<8xf32> %2 = mhlo.constant dense<0.000000e+00> : tensor<10x8xf32> %3 = mhlo.constant dense<-1.000000e+00> : tensor<10x8xf32> - %5 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>) -> tensor<10x8xf32> + %5 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<8xf32>) -> tensor<10x8xf32> %6 = mhlo.remainder %arg0, %5 : tensor<10x8xf32> %7 = "mhlo.compare"(%6, %2) {comparison_direction = #mhlo} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1> %8 = "mhlo.sign"(%6) : (tensor<10x8xf32>) -> tensor<10x8xf32> @@ -3890,7 +3890,7 @@ 
func.func @convert_gather_to_slice_dynamic_error(%arg0: tensor<3x?xi32>, %arg1: // CHECK: return %[[VAL_14]] : tensor<4x2xf32> // CHECK: } func.func @convert_dynamic_slice(%arg0: tensor<7x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<4x2xf32> { - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4, 2]> : tensor<2xi64>} : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> func.return %0 : tensor<4x2xf32> } @@ -3913,7 +3913,7 @@ func.func @convert_dynamic_slice(%arg0: tensor<7x3xf32>, %arg1: tensor, %ar // CHECK: return %[[VAL_14]] : tensor<4x2xf32> // CHECK: } func.func @convert_dynamic_slice_ui32(%arg0: tensor<7x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<4x2xf32> { - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4, 2]> : tensor<2xi64>} : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> func.return %0 : tensor<4x2xf32> } @@ -3921,10 +3921,12 @@ func.func @convert_dynamic_slice_ui32(%arg0: tensor<7x3xf32>, %arg1: tensor, // CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{ +// CHECK-SAME: indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false +// CHECK-SAME: }> ({ // CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): // CHECK: mhlo.return %[[VAL_5]] : tensor -// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> +// CHECK: 
}) : (tensor<20x6xf32>, tensor<4xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> // CHECK: return %[[VAL_3]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_update(%arg0: tensor<20x6xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { @@ -3947,10 +3949,12 @@ func.func @convert_scatter_update(%arg0: tensor<20x6xf32>, %arg1: tensor<4xi32>, // CHECK-SAME: %[[VAL_0:.*]]: tensor<5x10xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<3x1xi32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<10x3xf32>) -> tensor<5x10xf32> { -// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{ +// CHECK-SAME: indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false +// CHECK-SAME: }> ({ // CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): // CHECK: mhlo.return %[[VAL_5]] : tensor -// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<5x10xf32>, tensor<3x1xi32>, tensor<10x3xf32>) -> tensor<5x10xf32> +// CHECK: }) : (tensor<5x10xf32>, tensor<3x1xi32>, tensor<10x3xf32>) -> tensor<5x10xf32> // CHECK: return %[[VAL_3]] : tensor<5x10xf32> // CHECK: } func.func @convert_scatter_update_with_non_trailing_update_window_dims( @@ -3977,10 +3981,12 @@ func.func @convert_scatter_update_with_non_trailing_update_window_dims( // CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x3x7xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x2xi32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> { -// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{ +// CHECK-SAME: indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false +// CHECK-SAME: }> ({ // CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): // CHECK: mhlo.return %[[VAL_5]] : 
tensor -// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<5x4x3x7xf32>, tensor<2x2xi32>, tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> +// CHECK: }) : (tensor<5x4x3x7xf32>, tensor<2x2xi32>, tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> // CHECK: return %[[VAL_3]] : tensor<5x4x3x7xf32> // CHECK: } func.func @convert_scatter_update_to_non_trailing_operand_dimensions( @@ -4006,10 +4012,12 @@ func.func @convert_scatter_update_to_non_trailing_operand_dimensions( // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x1504xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<16xf32>) -> tensor<16x1504xf32> { -// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{ +// CHECK-SAME: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true +// CHECK-SAME: }> ({ // CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): // CHECK: mhlo.return %[[VAL_5]] : tensor -// CHECK: }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<16x1504xf32>, tensor<1xi32>, tensor<16xf32>) -> tensor<16x1504xf32> +// CHECK: }) : (tensor<16x1504xf32>, tensor<1xi32>, tensor<16xf32>) -> tensor<16x1504xf32> // CHECK: return %[[VAL_3]] : tensor<16x1504xf32> // CHECK: } func.func @convert_scatter_update_reshape_indices_and_updates( @@ -4034,11 +4042,13 @@ func.func @convert_scatter_update_reshape_indices_and_updates( // CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{ +// CHECK-SAME: indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, 
unique_indices = false +// CHECK-SAME: }> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: tensor, %[[VAL_4:.*]]: tensor): // CHECK: %[[VAL_5:.*]] = mhlo.add %[[VAL_3]], %[[VAL_4]] : tensor // CHECK: mhlo.return %[[VAL_5]] : tensor -// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> +// CHECK: }) : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> // CHECK: return %[[VAL_6]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_add(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { @@ -4062,11 +4072,13 @@ func.func @convert_scatter_add(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, // CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{ +// CHECK-SAME: indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false +// CHECK-SAME: }> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: tensor, %[[VAL_4:.*]]: tensor): // CHECK: %[[VAL_5:.*]] = mhlo.maximum %[[VAL_3]], %[[VAL_4]] : tensor // CHECK: mhlo.return %[[VAL_5]] : tensor -// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> +// CHECK: }) : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> // CHECK: return %[[VAL_6]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_max(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { @@ -4090,11 +4102,13 @@ func.func @convert_scatter_max(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, // CHECK-SAME: 
%[[VAL_0:.*]]: tensor<20x6xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{ +// CHECK-SAME: indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false +// CHECK-SAME: }> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: tensor, %[[VAL_4:.*]]: tensor): // CHECK: %[[VAL_5:.*]] = mhlo.minimum %[[VAL_3]], %[[VAL_4]] : tensor // CHECK: mhlo.return %[[VAL_5]] : tensor -// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> +// CHECK: }) : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> // CHECK: return %[[VAL_6]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_min(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { @@ -4118,11 +4132,13 @@ func.func @convert_scatter_min(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, // CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{ +// CHECK-SAME: indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false +// CHECK-SAME: }> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: tensor, %[[VAL_4:.*]]: tensor): // CHECK: %[[VAL_5:.*]] = mhlo.subtract %[[VAL_3]], %[[VAL_4]] : tensor // CHECK: mhlo.return %[[VAL_5]] : tensor -// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> 
+// CHECK: }) : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> // CHECK: return %[[VAL_6]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_sub(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { @@ -4146,7 +4162,7 @@ func.func @convert_scatter_sub(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, func.func @convert_pytorch_argmax(%arg0: tensor<1x9xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<-2147483648> : tensor - %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<9xi32> + %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<9xi32> %3 = mhlo.reshape %2 : (tensor<9xi32>) -> tensor<1x9xi32> %4:2 = mhlo.reduce(%arg0 init: %1), (%3 init: %0) across dimensions = [1] : (tensor<1x9xi32>, tensor<1x9xi32>, tensor, tensor) -> (tensor<1xi32>, tensor<1xi32>) reducer(%arg1: tensor, %arg3: tensor) (%arg2: tensor, %arg4: tensor) { @@ -4184,8 +4200,8 @@ func.func @convert_pytorch_argmax(%arg0: tensor<1x9xi32>) -> tensor<1xi32> { func.func @convert_argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) { %0 = mhlo.constant dense<0xFF800000> : tensor %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<256xi32> - %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xi32>) -> tensor<4x32x256xi32> + %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> + %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> %4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %7 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor @@ -4277,7 +4293,7 @@ func.func @convert_argmax_constant_non_z_axis(%arg0: tensor<4x4xf32>) -> (tensor // CHECK: } func.func 
@convert_argmax_bool(%arg0: tensor<2xi1>) -> tensor { - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> %1 = mhlo.constant dense : tensor %2 = mhlo.constant dense<0> : tensor %3:2 = mhlo.reduce(%arg0 init: %1), (%0 init: %2) across dimensions = [0] : (tensor<2xi1>, tensor<2xi32>, tensor, tensor) -> (tensor, tensor) @@ -4304,8 +4320,8 @@ func.func @convert_argmax_bool(%arg0: tensor<2xi1>) -> tensor { func.func @convert_argmin(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) { %0 = mhlo.constant dense<0x7F800000> : tensor %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<256xi32> - %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xi32>) -> tensor<4x32x256xi32> + %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> + %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> %4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %7 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor @@ -4331,7 +4347,7 @@ func.func @convert_argmin(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, ten // CHECK: } func.func @convert_argmin_i16(%arg0: tensor<2xi16>) -> (tensor, tensor) { %0 = mhlo.constant dense : tensor - %1 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi32> + %1 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> %2 = mhlo.constant dense<32767> : tensor %3 = mhlo.constant dense<0> : tensor %4:2 = "mhlo.reduce"(%arg0, %1, %2, %3) ({ @@ -4394,7 +4410,7 @@ func.func @convert_argmin_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32> // CHECK: return %[[VAL_9]] : tensor // CHECK: } func.func @convert_argmin_bool(%arg0: tensor<2xi1>) -> tensor { - %0 = 
"mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> %1 = mhlo.constant dense : tensor %2 = mhlo.constant dense<0> : tensor %3:2 = mhlo.reduce(%arg0 init: %1), (%0 init: %2) across dimensions = [0] : (tensor<2xi1>, tensor<2xi32>, tensor, tensor) -> (tensor, tensor) @@ -4429,7 +4445,7 @@ func.func @convert_argmin_bool(%arg0: tensor<2xi1>) -> tensor { func.func @convert_argmax_with_reshaped_iota(%arg0: tensor<1x32x1xf32>) -> (tensor<1x1xf32>, tensor<1x1xi32>) { %0 = mhlo.constant dense<0xFF800000> : tensor %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<32xi32> + %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32> %3 = "mhlo.reshape"(%2) : (tensor<32xi32>) -> tensor<1x32x1xi32> %4:2 = mhlo.reduce(%arg0 init: %0), (%3 init: %1) across dimensions = [1] : (tensor<1x32x1xf32>, tensor<1x32x1xi32>, tensor, tensor) -> (tensor<1x1xf32>, tensor<1x1xi32>) reducer(%arg1: tensor, %arg3: tensor) (%arg2: tensor, %arg4: tensor) { @@ -4772,8 +4788,8 @@ func.func @convert_reduce_to_any_non_constant_init(%arg0: tensor, %arg1: ten // CHECK: return %[[VALUES]], %[[INDICES]] : tensor<3x6xf32>, tensor<3x6xi32> // CHECK: } func.func @convert_sort_to_topk_iota_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) { - %0 = "mhlo.iota"() { iota_dimension = 0 : i64 } : () -> tensor<6xi32> - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"} : (tensor<6xi32>) -> tensor<3x6xi32> + %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<6xi32> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32> %2:2 = "mhlo.sort"(%arg0, %1) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo, comparison_direction = 
#mhlo} : (tensor, tensor) -> tensor @@ -4793,7 +4809,7 @@ func.func @convert_sort_to_topk_iota_broadcast(%arg0: tensor<3x6xf32>) -> (tenso // CHECK: } func.func @convert_sort_to_topk_iotacst_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) { %0 = mhlo.constant dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32> - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"} : (tensor<6xi32>) -> tensor<3x6xi32> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32> %2:2 = "mhlo.sort"(%arg0, %1) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor @@ -4996,7 +5012,7 @@ func.func @convert_dot_quant_type(%arg0: tensor<1x256xf32>, %arg1: tensor<256x!q // CHECK %[[CST_0:.*]] = "tf.Const"() <{value = dense<256> : tensor}> : () -> tensor // CHECK return %[[CST_0]] : tensor func.func @get_dimension_size(%arg0: tensor<4x256x?xf32>) -> tensor { - %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor<4x256x?xf32>) -> tensor + %0 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor<4x256x?xf32>) -> tensor func.return %0 : tensor } @@ -5009,7 +5025,7 @@ func.func @get_dimension_size(%arg0: tensor<4x256x?xf32>) -> tensor { // CHECK %[[VAL_2:.*]] = "tf.Squeeze"(%[[VAL_1]]) <{squeeze_dims = [0]}> : (tensor<1xi32>) -> tensor // CHECK return %[[VAL_2]] : tensor func.func @get_dimension_size_dynamic(%arg0: tensor<4x256x?xf32>) -> tensor { - %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 2 : i64} : (tensor<4x256x?xf32>) -> tensor + %0 = "mhlo.get_dimension_size"(%arg0) <{dimension = 2 : i64}> : (tensor<4x256x?xf32>) -> tensor func.return %0 : tensor } @@ -5022,7 +5038,7 @@ func.func @get_dimension_size_dynamic(%arg0: tensor<4x256x?xf32>) -> tensor // CHECK: %[[VAL_1:.*]] = 
"tf.Range"(%[[CST_1]], %[[VAL_0]], %[[CST_2]]) : (tensor, tensor, tensor) -> tensor // CHECK: return %[[VAL_1]] : tensor func.func @dynamic_iota_i32_1d(%arg0: tensor<1xi32>) -> tensor { - %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor + %0 = "mhlo.dynamic_iota"(%arg0) <{iota_dimension = 0 : i64}> : (tensor<1xi32>) -> tensor func.return %0 : tensor } @@ -5036,7 +5052,7 @@ func.func @dynamic_iota_i32_1d(%arg0: tensor<1xi32>) -> tensor { // CHECK: %[[VAL_2:.*]] = "tf.Range"(%[[CST_1]], %[[VAL_1]], %[[CST_2]]) : (tensor, tensor, tensor) -> tensor // CHECK: return %[[VAL_2]] : tensor func.func @dynamic_iota_f32_1d(%arg0: tensor<1xi32>) -> tensor { - %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor + %0 = "mhlo.dynamic_iota"(%arg0) <{iota_dimension = 0 : i64}> : (tensor<1xi32>) -> tensor func.return %0 : tensor } @@ -5070,7 +5086,7 @@ func.return %0 : tensor<1x?x1x2xf32> // CHECK-NOT: "mhlo.custom_call" func.func @remove_shape_assertion_custom_call(%arg1: tensor) -> tensor { %0 = mhlo.constant dense<3> : tensor - %1 = "mhlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %1 = "mhlo.get_dimension_size"(%arg1) <{dimension = 0 : i64}> : (tensor) -> tensor %ok = mhlo.compare EQ, %1, %0, SIGNED : (tensor, tensor) -> tensor mhlo.custom_call @shape_assertion(%ok) { error_message = "The error message", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-allow-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-allow-tf.mlir index 608b90d54a7f72..4596d21637b69e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-allow-tf.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-allow-tf.mlir @@ -3,8 +3,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 975 : i32}, tf_saved_model.semantics} { func.func @serving_default(%arg0: tensor<1x20x20x28xf32> 
{tf_saved_model.index_path = ["a"]}) -> (tensor<1x40x40x28xf32> {tf_saved_model.index_path = ["b"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "c:0", outputs = "d:0"}, tf_saved_model.exported_names = ["serving_default"]} { - %0 = stablehlo.constant dense<40> : tensor<2xi32> - %1 = "tf.UnconvertedOp"(%arg0, %0) {align_corners = false, half_pixel_centers = false} : (tensor<1x20x20x28xf32>, tensor<2xi32>) -> tensor<1x40x40x28xf32> - func.return %1 : tensor<1x40x40x28xf32> + %c = stablehlo.constant dense<40> : tensor<2xi32> + %0 = "tf.UnconvertedOp"(%arg0, %c) {align_corners = false, half_pixel_centers = false} : (tensor<1x20x20x28xf32>, tensor<2xi32>) -> tensor<1x40x40x28xf32> + func.return %0 : tensor<1x40x40x28xf32> } } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir index 4a0f6a5d5e673b..eee056a8a489f9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir @@ -5,8 +5,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 975 : i32}, tf_saved_model.semantics} { func.func @serving_default(%arg0: tensor<1x32x32x128xf32> {tf_saved_model.index_path = ["a"]}) -> (tensor<1x64x64x128xf32> {tf_saved_model.index_path = ["b"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "c:0", outputs = "d:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = "tf.Const"() {value = dense<[56, 904]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: %1 = stablehlo.custom_call @tf.ResizeBilinear(%arg0, %0) {align_corners = false, device = "", half_pixel_centers = true} : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> - // CHECK-OPT: %0 = stablehlo.custom_call @tf.ResizeBilinear(%arg0, %cst) {align_corners = false, device 
= "", half_pixel_centers = true} : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> + // CHECK: %{{.*}} = stablehlo.custom_call @tf.ResizeBilinear(%arg0, %{{.*}}) {align_corners = false, device = "", half_pixel_centers = true} : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> + // CHECK-OPT: %{{.*}} = stablehlo.custom_call @tf.ResizeBilinear(%arg0, %cst) {align_corners = false, device = "", half_pixel_centers = true} : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> %1 = "tf.ResizeBilinear"(%arg0, %0) { align_corners = false, device = "", half_pixel_centers = true } : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir index 722fc5b47459f8..7db47a1a3e7703 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir @@ -5,11 +5,11 @@ func.func @testDotToDotGeneralVectorVector(%arg0: tensor<3072xf32>, %arg1: tenso %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<3072xf32>, tensor<3072xf32>) -> tensor func.return %0 : tensor -// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [0], // CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: >} : (tensor<3072xf32>, tensor<3072xf32>) -> tensor +// CHECK-SAME: >}> : (tensor<3072xf32>, tensor<3072xf32>) -> tensor // CHECK: return %[[RES]] : tensor } @@ -20,11 +20,11 @@ func.func @testDotToDotGeneralVectorMatrix(%arg0: tensor<3072xf32>, %arg1: tenso %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<3072xf32>, tensor<3072x512xf32>) -> tensor<512xf32> func.return %0 : tensor<512xf32> -// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, 
%arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [0], // CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: >} : (tensor<3072xf32>, tensor<3072x512xf32>) -> tensor<512xf32> +// CHECK-SAME: >}> : (tensor<3072xf32>, tensor<3072x512xf32>) -> tensor<512xf32> // CHECK: return %[[RES]] : tensor<512xf32> } @@ -35,11 +35,11 @@ func.func @testDotToDotGeneralMatrixVector(%arg0: tensor<2x3072xf32>, %arg1: ten %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3072xf32>, tensor<3072xf32>) -> tensor<2xf32> func.return %0 : tensor<2xf32> -// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [1], // CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: >} : (tensor<2x3072xf32>, tensor<3072xf32>) -> tensor<2xf32> +// CHECK-SAME: >}> : (tensor<2x3072xf32>, tensor<3072xf32>) -> tensor<2xf32> // CHECK: return %[[RES]] : tensor<2xf32> } @@ -50,11 +50,11 @@ func.func @testDotToDotGeneralMatrixMatrix(%arg0: tensor<2x3072xf32>, %arg1: ten %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3072xf32>, tensor<3072x512xf32>) -> tensor<2x512xf32> func.return %0 : tensor<2x512xf32> -// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [1], // CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: >} : (tensor<2x3072xf32>, tensor<3072x512xf32>) -> tensor<2x512xf32> +// CHECK-SAME: >}> : (tensor<2x3072xf32>, tensor<3072x512xf32>) -> tensor<2x512xf32> // CHECK: return %[[RES]] : tensor<2x512xf32> } @@ -73,13 +73,13 @@ func.func @testRemoveReshapeAroundDotGeneral(%arg0: tensor<3x72x1x2048xf32>, %ar %2 = "mhlo.reshape"(%1) : (tensor<3x72x512xf32>) -> tensor<3x72x1x512xf32> func.return %2 : tensor<3x72x1x512xf32> -// CHECK: 
%[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_batching_dimensions = [0], // CHECK-SAME: rhs_batching_dimensions = [0], // CHECK-SAME: lhs_contracting_dimensions = [3], // CHECK-SAME: rhs_contracting_dimensions = [1] -// CHECK-SAME: >} : (tensor<3x72x1x2048xf32>, tensor<3x2048x512xf32>) -> tensor<3x72x1x512xf32> +// CHECK-SAME: >}> : (tensor<3x72x1x2048xf32>, tensor<3x2048x512xf32>) -> tensor<3x72x1x512xf32> // CHECK: return %[[RES]] : tensor<3x72x1x512xf32> } @@ -92,11 +92,11 @@ func.func @testRemoveReshapeAroundDot(%arg0: tensor<1x1x512xf32>, %arg1: tensor< %2 = "mhlo.reshape"(%1) : (tensor<1x13xf32>) -> tensor<1x1x13xf32> func.return %2 : tensor<1x1x13xf32> -// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [2], // CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: >} : (tensor<1x1x512xf32>, tensor<512x13x!quant.uniform>) -> tensor<1x1x13xf32> +// CHECK-SAME: >}> : (tensor<1x1x512xf32>, tensor<512x13x!quant.uniform>) -> tensor<1x1x13xf32> // CHECK: return %[[RES]] : tensor<1x1x13xf32> } @@ -105,15 +105,15 @@ func.func @testRemoveReshapeAroundDot(%arg0: tensor<1x1x512xf32>, %arg1: tensor< // CHECK-LABEL: testTwoConsecutivePads func.func @testTwoConsecutivePads(%arg0: tensor<10x10x10xf32>) -> (tensor<12x12x12xf32>) { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<0> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> + %1 = "mhlo.pad"(%arg0, %0) <{edge_padding_high = dense<0> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : 
(tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> %2 = mhlo.constant dense<0.000000e+00> : tensor - %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<11x11x11xf32>, tensor) -> tensor<12x12x12xf32> + %3 = "mhlo.pad"(%1, %2) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<11x11x11xf32>, tensor) -> tensor<12x12x12xf32> return %3 : tensor<12x12x12xf32> -// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) { +// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) <{ // CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> +// CHECK-SAME: }> : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> // CHECK: return %[[RES]] : tensor<12x12x12xf32> } @@ -122,16 +122,16 @@ func.func @testTwoConsecutivePads(%arg0: tensor<10x10x10xf32>) -> (tensor<12x12x // CHECK-LABEL: testTwoConsecutivePadsNegativeLowPad func.func @testTwoConsecutivePadsNegativeLowPad(%arg0: tensor<10x10x10xf32>) -> (tensor<10x10x10xf32>) { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<0> : tensor<3xi64>, edge_padding_low = dense<-1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<9x9x9xf32> + %1 = "mhlo.pad"(%arg0, %0) <{edge_padding_high = dense<0> : tensor<3xi64>, edge_padding_low = dense<-1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<10x10x10xf32>, tensor) -> tensor<9x9x9xf32> %2 = mhlo.constant dense<0.000000e+00> : tensor - %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = 
dense<0> : tensor<3xi64>} : (tensor<9x9x9xf32>, tensor) -> tensor<10x10x10xf32> + %3 = "mhlo.pad"(%1, %2) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<9x9x9xf32>, tensor) -> tensor<10x10x10xf32> return %3 : tensor<10x10x10xf32> -// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) { +// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) <{ // CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<-1> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> +// CHECK-SAME: }> : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> // CHECK: return %[[RES]] : tensor<10x10x10xf32> } @@ -140,16 +140,16 @@ func.func @testTwoConsecutivePadsNegativeLowPad(%arg0: tensor<10x10x10xf32>) -> // CHECK-LABEL: testTwoConsecutivePadsTwoNegativeHighPad func.func @testTwoConsecutivePadsTwoNegativeHighPad(%arg0: tensor<10x10x10xf32>) -> (tensor<9x9x9xf32>) { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> + %1 = "mhlo.pad"(%arg0, %0) <{edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> %2 = mhlo.constant dense<0.000000e+00> : tensor - %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<9x9x9xf32> + %3 = "mhlo.pad"(%1, %2) <{edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : 
tensor<3xi64>}> : (tensor<10x10x10xf32>, tensor) -> tensor<9x9x9xf32> return %3 : tensor<9x9x9xf32> -// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) { +// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) <{ // CHECK-SAME: edge_padding_high = dense<-2> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<9x9x9xf32> +// CHECK-SAME: }> : (tensor<10x10x10xf32>, tensor) -> tensor<9x9x9xf32> // CHECK: return %[[RES]] : tensor<9x9x9xf32> } @@ -158,16 +158,16 @@ func.func @testTwoConsecutivePadsTwoNegativeHighPad(%arg0: tensor<10x10x10xf32>) // CHECK-LABEL: testTwoConsecutivePadsPositiveNegativeHighPad func.func @testTwoConsecutivePadsPositiveNegativeHighPad(%arg0: tensor<10x10x10xf32>) -> (tensor<11x11x11xf32>) { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + %1 = "mhlo.pad"(%arg0, %0) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> %2 = mhlo.constant dense<0.000000e+00> : tensor - %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<12x12x12xf32>, tensor) -> tensor<11x11x11xf32> + %3 = "mhlo.pad"(%1, %2) <{edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<12x12x12xf32>, tensor) -> tensor<11x11x11xf32> return %3 : tensor<11x11x11xf32> -// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) { +// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) <{ // CHECK-SAME: 
edge_padding_high = dense<0> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> +// CHECK-SAME: }> : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> // CHECK: return %[[RES]] : tensor<11x11x11xf32> } @@ -176,22 +176,22 @@ func.func @testTwoConsecutivePadsPositiveNegativeHighPad(%arg0: tensor<10x10x10x // CHECK-LABEL: testTwoConsecutivePadsNegativePositiveHighPad func.func @testTwoConsecutivePadsNegativePositiveHighPad(%arg0: tensor<10x10x10xf32>) -> (tensor<11x11x11xf32>) { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> + %1 = "mhlo.pad"(%arg0, %0) <{edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> %2 = mhlo.constant dense<0.000000e+00> : tensor - %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> + %3 = "mhlo.pad"(%1, %2) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> return %3 : tensor<11x11x11xf32> -// CHECK: "mhlo.pad"(%arg0, %0) { +// CHECK: "mhlo.pad"(%arg0, %0) <{ // CHECK-SAME: edge_padding_high = dense<-1> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> +// 
CHECK-SAME: }> : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> -// CHECK: "mhlo.pad"(%1, %0) { +// CHECK: "mhlo.pad"(%1, %0) <{ // CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<0> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> +// CHECK-SAME: }> : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> } // ----- @@ -199,22 +199,22 @@ func.func @testTwoConsecutivePadsNegativePositiveHighPad(%arg0: tensor<10x10x10x // CHECK-LABEL: testTwoConsecutivePadsDifferentPadVal func.func @testTwoConsecutivePadsDifferentPadVal(%arg0: tensor<10x10x10xf32>) -> (tensor<14x14x14xf32>) { %0 = mhlo.constant dense<1.000000e+00> : tensor - %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + %1 = "mhlo.pad"(%arg0, %0) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> %2 = mhlo.constant dense<0.000000e+00> : tensor - %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<12x12x12xf32>, tensor) -> tensor<14x14x14xf32> + %3 = "mhlo.pad"(%1, %2) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<12x12x12xf32>, tensor) -> tensor<14x14x14xf32> return %3 : tensor<14x14x14xf32> -// CHECK: "mhlo.pad"(%arg0, %1) { +// CHECK: "mhlo.pad"(%arg0, %1) <{ // CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : 
tensor<3xi64> -// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> +// CHECK-SAME: }> : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> -// CHECK: "mhlo.pad"(%2, %0) { +// CHECK: "mhlo.pad"(%2, %0) <{ // CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<12x12x12xf32>, tensor) -> tensor<14x14x14xf32> +// CHECK-SAME: }> : (tensor<12x12x12xf32>, tensor) -> tensor<14x14x14xf32> } // ----- @@ -222,23 +222,23 @@ func.func @testTwoConsecutivePadsDifferentPadVal(%arg0: tensor<10x10x10xf32>) -> // CHECK-LABEL: testTwoConsecutivePadsDifferentUsers func.func @testTwoConsecutivePadsDifferentUsers(%arg0: tensor<10x10x10xf32>) -> (tensor<13x13x13xf32>, tensor<12x12x12xf32>) { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + %1 = "mhlo.pad"(%arg0, %0) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> %2 = mhlo.exponential %1 : tensor<12x12x12xf32> %3 = mhlo.constant dense<0.000000e+00> : tensor - %4 = "mhlo.pad"(%1, %3) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<12x12x12xf32>, tensor) -> tensor<13x13x13xf32> + %4 = "mhlo.pad"(%1, %3) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<12x12x12xf32>, tensor) -> tensor<13x13x13xf32> return %4, %2 : tensor<13x13x13xf32>, tensor<12x12x12xf32> -// CHECK: "mhlo.pad"(%arg0, %0) { +// CHECK: 
"mhlo.pad"(%arg0, %0) <{ // CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> +// CHECK-SAME: }> : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> -// CHECK: "mhlo.pad"(%1, %0) { +// CHECK: "mhlo.pad"(%1, %0) <{ // CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<0> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<12x12x12xf32>, tensor) -> tensor<13x13x13xf32> +// CHECK-SAME: }> : (tensor<12x12x12xf32>, tensor) -> tensor<13x13x13xf32> } // ----- @@ -246,18 +246,18 @@ func.func @testTwoConsecutivePadsDifferentUsers(%arg0: tensor<10x10x10xf32>) -> // CHECK-LABEL: testTwoConsecutivePadsMultipleDownstreamUsers func.func @testTwoConsecutivePadsMultipleDownstreamUsers(%arg0: tensor<10x10x10xf32>) -> (tensor<13x13x13xf32>, tensor<13x13x13xf32>) { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + %1 = "mhlo.pad"(%arg0, %0) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> %2 = mhlo.constant dense<0.000000e+00> : tensor - %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<12x12x12xf32>, tensor) -> tensor<13x13x13xf32> + %3 = "mhlo.pad"(%1, %2) <{edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : 
(tensor<12x12x12xf32>, tensor) -> tensor<13x13x13xf32> %4 = mhlo.exponential %3 : tensor<13x13x13xf32> %5 = mhlo.tanh %3 : tensor<13x13x13xf32> return %4, %5 : tensor<13x13x13xf32>, tensor<13x13x13xf32> -// CHECK: "mhlo.pad"(%arg0, %0) { +// CHECK: "mhlo.pad"(%arg0, %0) <{ // CHECK-SAME: edge_padding_high = dense<2> : tensor<3xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> -// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<13x13x13xf32> +// CHECK-SAME: }> : (tensor<10x10x10xf32>, tensor) -> tensor<13x13x13xf32> // CHECK: mhlo.exponential %1 : tensor<13x13x13xf32> // CHECK: mhlo.tanh %1 : tensor<13x13x13xf32> @@ -283,15 +283,15 @@ func.func @testLiftDotConcatLHSSimple(%arg0: tensor<1x1x512xf32>, %arg1: tensor< lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0] >} : (tensor<3x1x512xf32>, tensor<512x13xf32>) -> tensor<3x1x13xf32> - %r = "mhlo.concatenate"(%0, %1, %2) {dimension = 0 : i64} : (tensor<1x1x13xf32>, tensor<2x1x13xf32>, tensor<3x1x13xf32>) -> tensor<6x1x13xf32> + %r = "mhlo.concatenate"(%0, %1, %2) <{dimension = 0 : i64}> : (tensor<1x1x13xf32>, tensor<2x1x13xf32>, tensor<3x1x13xf32>) -> tensor<6x1x13xf32> func.return %r : tensor<6x1x13xf32> -// CHECK: %[[R0:.*]] = "mhlo.concatenate"(%arg0, %arg1, %arg2) {dimension = 0 : i64} : (tensor<1x1x512xf32>, tensor<2x1x512xf32>, tensor<3x1x512xf32>) -> tensor<6x1x512xf32> -// CHECK: %[[R1:.*]] = "mhlo.dot_general"(%[[R0]], %arg3) { +// CHECK: %[[R0:.*]] = "mhlo.concatenate"(%arg0, %arg1, %arg2) <{dimension = 0 : i64}> : (tensor<1x1x512xf32>, tensor<2x1x512xf32>, tensor<3x1x512xf32>) -> tensor<6x1x512xf32> +// CHECK: %[[R1:.*]] = "mhlo.dot_general"(%[[R0]], %arg3) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [2], // CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: >} : (tensor<6x1x512xf32>, tensor<512x13xf32>) -> tensor<6x1x13xf32> +// CHECK-SAME: >}> : 
(tensor<6x1x512xf32>, tensor<512x13xf32>) -> tensor<6x1x13xf32> // CHECK: return %[[R1]] : tensor<6x1x13xf32> } @@ -313,17 +313,17 @@ func.func @testLiftDotConcatLHSComplex(%arg0: tensor<1x9x2x3x8x4x10xf32>, %arg1: lhs_contracting_dimensions = [4, 1, 6], rhs_contracting_dimensions = [6, 0, 4] >} : (tensor<1x9x2x3x8x100x10xf32>, tensor<9x2x1x5x10x5x8x7xf32>) -> tensor<1x2x3x100x5x5x7xf32> - %r = "mhlo.concatenate"(%0, %1) {dimension = 3 : i64} : (tensor<1x2x3x4x5x5x7xf32>, tensor<1x2x3x100x5x5x7xf32>) -> tensor<1x2x3x104x5x5x7xf32> + %r = "mhlo.concatenate"(%0, %1) <{dimension = 3 : i64}> : (tensor<1x2x3x4x5x5x7xf32>, tensor<1x2x3x100x5x5x7xf32>) -> tensor<1x2x3x104x5x5x7xf32> func.return %r : tensor<1x2x3x104x5x5x7xf32> -// CHECK: %[[R0:.*]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 5 : i64} : (tensor<1x9x2x3x8x4x10xf32>, tensor<1x9x2x3x8x100x10xf32>) -> tensor<1x9x2x3x8x104x10xf32> -// CHECK: %[[R1:.*]] = "mhlo.dot_general"(%[[R0]], %arg2) { +// CHECK: %[[R0:.*]] = "mhlo.concatenate"(%arg0, %arg1) <{dimension = 5 : i64}> : (tensor<1x9x2x3x8x4x10xf32>, tensor<1x9x2x3x8x100x10xf32>) -> tensor<1x9x2x3x8x104x10xf32> +// CHECK: %[[R1:.*]] = "mhlo.dot_general"(%[[R0]], %arg2) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_batching_dimensions = [0, 2], // CHECK-SAME: rhs_batching_dimensions = [2, 1], // CHECK-SAME: lhs_contracting_dimensions = [4, 1, 6], // CHECK-SAME: rhs_contracting_dimensions = [6, 0, 4] -// CHECK-SAME: >} : (tensor<1x9x2x3x8x104x10xf32>, tensor<9x2x1x5x10x5x8x7xf32>) -> tensor<1x2x3x104x5x5x7xf32> +// CHECK-SAME: >}> : (tensor<1x9x2x3x8x104x10xf32>, tensor<9x2x1x5x10x5x8x7xf32>) -> tensor<1x2x3x104x5x5x7xf32> // CHECK: return %[[R1]] : tensor<1x2x3x104x5x5x7xf32> } @@ -359,18 +359,18 @@ func.func @testLiftDotConcatLHSAndRHS(%arg0: tensor<1x72x128xf32>, %arg1: tensor lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1] >} : (tensor<1x72x128xf32>, tensor<1x128x72xf32>) -> tensor<1x72x72xf32> - %4 = 
"mhlo.concatenate"(%0, %1, %2, %3) {dimension = 0 : i64} : (tensor<1x72x72xf32>, tensor<1x72x72xf32>, tensor<1x72x72xf32>, tensor<1x72x72xf32>) -> tensor<4x72x72xf32> + %4 = "mhlo.concatenate"(%0, %1, %2, %3) <{dimension = 0 : i64}> : (tensor<1x72x72xf32>, tensor<1x72x72xf32>, tensor<1x72x72xf32>, tensor<1x72x72xf32>) -> tensor<4x72x72xf32> func.return %4 : tensor<4x72x72xf32> -// CHECK: %[[R0:.*]] = "mhlo.concatenate"(%arg0, %arg2, %arg4, %arg6) {dimension = 0 : i64} : (tensor<1x72x128xf32>, tensor<1x72x128xf32>, tensor<1x72x128xf32>, tensor<1x72x128xf32>) -> tensor<4x72x128xf32> -// CHECK: %[[R1:.*]] = "mhlo.concatenate"(%arg1, %arg3, %arg5, %arg7) {dimension = 0 : i64} : (tensor<1x128x72xf32>, tensor<1x128x72xf32>, tensor<1x128x72xf32>, tensor<1x128x72xf32>) -> tensor<4x128x72xf32> -// CHECK: %[[R2:.*]] = "mhlo.dot_general"(%[[R0]], %[[R1]]) { +// CHECK: %[[R0:.*]] = "mhlo.concatenate"(%arg0, %arg2, %arg4, %arg6) <{dimension = 0 : i64}> : (tensor<1x72x128xf32>, tensor<1x72x128xf32>, tensor<1x72x128xf32>, tensor<1x72x128xf32>) -> tensor<4x72x128xf32> +// CHECK: %[[R1:.*]] = "mhlo.concatenate"(%arg1, %arg3, %arg5, %arg7) <{dimension = 0 : i64}> : (tensor<1x128x72xf32>, tensor<1x128x72xf32>, tensor<1x128x72xf32>, tensor<1x128x72xf32>) -> tensor<4x128x72xf32> +// CHECK: %[[R2:.*]] = "mhlo.dot_general"(%[[R0]], %[[R1]]) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_batching_dimensions = [0], // CHECK-SAME: rhs_batching_dimensions = [0], // CHECK-SAME: lhs_contracting_dimensions = [2], // CHECK-SAME: rhs_contracting_dimensions = [1] -// CHECK-SAME: >} : (tensor<4x72x128xf32>, tensor<4x128x72xf32>) -> tensor<4x72x72xf32> +// CHECK-SAME: >}> : (tensor<4x72x128xf32>, tensor<4x128x72xf32>) -> tensor<4x72x72xf32> // CHECK: return %[[R2]] : tensor<4x72x72xf32> } @@ -378,10 +378,10 @@ func.func @testLiftDotConcatLHSAndRHS(%arg0: tensor<1x72x128xf32>, %arg1: tensor // CHECK-LABEL: testSliceConcat func.func @testSliceConcat(%arg0: tensor<3x1x512xf32>) 
-> tensor<3x1x512xf32> { - %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 1, 512]> : tensor<3xi64>, start_indices = dense<[0, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<3x1x512xf32>) -> tensor<1x1x512xf32> - %1 = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 1, 512]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<3x1x512xf32>) -> tensor<1x1x512xf32> - %2 = "mhlo.slice"(%arg0) {limit_indices = dense<[3, 1, 512]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<3x1x512xf32>) -> tensor<1x1x512xf32> - %r = "mhlo.concatenate"(%0, %1, %2) {dimension = 0 : i64} : (tensor<1x1x512xf32>, tensor<1x1x512xf32>, tensor<1x1x512xf32>) -> tensor<3x1x512xf32> + %0 = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 1, 512]> : tensor<3xi64>, start_indices = dense<[0, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<3x1x512xf32>) -> tensor<1x1x512xf32> + %1 = "mhlo.slice"(%arg0) <{limit_indices = dense<[2, 1, 512]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<3x1x512xf32>) -> tensor<1x1x512xf32> + %2 = "mhlo.slice"(%arg0) <{limit_indices = dense<[3, 1, 512]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<3x1x512xf32>) -> tensor<1x1x512xf32> + %r = "mhlo.concatenate"(%0, %1, %2) <{dimension = 0 : i64}> : (tensor<1x1x512xf32>, tensor<1x1x512xf32>, tensor<1x1x512xf32>) -> tensor<3x1x512xf32> func.return %r : tensor<3x1x512xf32> // CHECK: return %arg0 : tensor<3x1x512xf32> @@ -399,12 +399,12 @@ func.func @testConvertReshapeDotRhsToBatchedDot(%arg0: tensor<1x72x72xf32>, %arg >} : (tensor<1x72x72xf32>, tensor<72x128xf32>) -> tensor<1x72x128xf32> func.return %1 : tensor<1x72x128xf32> -// CHECK: %[[R:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK: %[[R:.*]] = 
"mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_batching_dimensions = [0], // CHECK-SAME: rhs_batching_dimensions = [0], // CHECK-SAME: lhs_contracting_dimensions = [2], // CHECK-SAME: rhs_contracting_dimensions = [1] -// CHECK-SAME: >} : (tensor<1x72x72xf32>, tensor<1x72x128xf32>) -> tensor<1x72x128xf32> +// CHECK-SAME: >}> : (tensor<1x72x72xf32>, tensor<1x72x128xf32>) -> tensor<1x72x128xf32> // CHECK: return %[[R]] : tensor<1x72x128xf32> } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir index d5384b4c96a1f3..479147ded9bb5d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir @@ -14,9 +14,9 @@ func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { //CHECK: module attributes //CHECK-SAME: keep_stablehlo_constant = "true" //CHECK-NEXT: func.func @main(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "vhlo.dynamic_update_slice_v1"}} { -//CHECK-DAG: %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> -//CHECK-DAG: %1 = stablehlo.constant dense<1> : tensor -//CHECK-DAG: %2 = stablehlo.constant dense<0> : tensor -//CHECK-NEXT: %3 = stablehlo.dynamic_update_slice %arg0, %0, %1, %2, %2 : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> -//CHECK-NEXT: return %3 : tensor<2x1x2xf32> +//CHECK-DAG: %[[c0:.+]] = stablehlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> +//CHECK-DAG: %[[c1:.+]] = stablehlo.constant dense<1> : tensor +//CHECK-DAG: %[[c2:.+]] = stablehlo.constant dense<0> : tensor +//CHECK-NEXT: %[[c3:.+]] = stablehlo.dynamic_update_slice %arg0, %[[c0]], %[[c1]], %[[c2]], %[[c2]] : (tensor<2x1x2xf32>, 
tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> +//CHECK-NEXT: return %[[c3]] : tensor<2x1x2xf32> //CHECK-NEXT: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-tf-quantize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-tf-quantize.mlir index d23de7ce50cef9..0b0988363a760a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-tf-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-tf-quantize.mlir @@ -13,10 +13,10 @@ func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { //CHECK: module { //CHECK-NEXT: func.func @main(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { -//CHECK-DAG: %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> -//CHECK-DAG: %1 = stablehlo.constant dense<1> : tensor -//CHECK-DAG: %2 = stablehlo.constant dense<0> : tensor -//CHECK-NEXT: %3 = stablehlo.dynamic_update_slice %arg0, %0, %1, %2, %2 : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> -//CHECK-NEXT: return %3 : tensor<2x1x2xf32> +//CHECK-DAG: %[[c0:.+]] = stablehlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> +//CHECK-DAG: %[[c1:.+]] = stablehlo.constant dense<1> : tensor +//CHECK-DAG: %[[c2:.+]] = stablehlo.constant dense<0> : tensor +//CHECK-NEXT: %[[c3:.+]] = stablehlo.dynamic_update_slice %arg0, %[[c0]], %[[c1]], %[[c2]], %[[c2]] : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> +//CHECK-NEXT: return %[[c3:.+]] : tensor<2x1x2xf32> //CHECK-NEXT: } //CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index 4dcfa9c3cbacd7..9be635a44268f6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -11,7 +11,7 @@ func.func 
@main(%arg0: tensor<5x7xf32>) -> tensor<5x7xf32> { // - transpose // func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<2x3xf32>) -> tensor<3x2xf32> func.return %0 : tensor<3x2xf32> // CHECK-LABEL: transpose_2d @@ -22,7 +22,7 @@ func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { } func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> func.return %0 : tensor<3x2x1xf32> // CHECK-LABEL: transpose_3d @@ -33,7 +33,7 @@ func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { } func.func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor) -> tensor<4x?xf32> + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor) -> tensor<4x?xf32> func.return %0 : tensor<4x?xf32> // CHECK-LABEL: transpose_dynamic_2d @@ -272,8 +272,8 @@ func.return %0 : tensor<4x4x256xf32> func.func @convert_argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) { %0 = mhlo.constant dense<0xFF800000> : tensor %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<256xi32> - %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xi32>) -> tensor<4x32x256xi32> + %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> + %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> 
%4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %7 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor @@ -291,8 +291,8 @@ func.func @convert_argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, ten // CHECK: %0 = mhlo.constant dense<0xFF800000> : tensor // CHECK-DAG: %1 = mhlo.constant dense<0> : tensor - // CHECK: %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<256xi32> - // CHECK: %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xi32>) -> tensor<4x32x256xi32> + // CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> + // CHECK: %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> // CHECK: %cst = arith.constant dense<2> : tensor<1xi32> // CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> // CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> @@ -359,7 +359,7 @@ func.func @convert_argmax_constant_non_z_axis(%arg0: tensor<4x4xf32>) -> (tensor // CHECK-LABEL: func.func @convert_argmax_bool func.func @convert_argmax_bool(%arg0: tensor<2xi1>) -> tensor { - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> %1 = mhlo.constant dense : tensor %2 = mhlo.constant dense<0> : tensor %3:2 = mhlo.reduce(%arg0 init: %1), (%0 init: %2) across dimensions = [0] : (tensor<2xi1>, tensor<2xi32>, tensor, tensor) -> (tensor, tensor) @@ -375,7 +375,7 @@ func.func @convert_argmax_bool(%arg0: tensor<2xi1>) -> tensor { } return %3#1 : tensor - // CHECK: %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi32> + // CHECK: %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> // CHECK-DAG: %1 = 
mhlo.constant dense : tensor // CHECK: %2 = mhlo.constant dense<0> : tensor // CHECK: %cst = arith.constant dense<0> : tensor<1xi32> @@ -388,8 +388,8 @@ func.func @convert_argmax_bool(%arg0: tensor<2xi1>) -> tensor { func.func @convert_argmin(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) { %0 = mhlo.constant dense<0x7F800000> : tensor %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<256xi32> - %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xi32>) -> tensor<4x32x256xi32> + %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> + %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> %4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %7 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor @@ -407,8 +407,8 @@ func.func @convert_argmin(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, ten // CHECK-DAG: %0 = mhlo.constant dense<0x7F800000> : tensor // CHECK: %1 = mhlo.constant dense<0> : tensor - // CHECK: %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<256xi32> - // CHECK: %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xi32>) -> tensor<4x32x256xi32> + // CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> + // CHECK: %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> // CHECK: %cst = arith.constant dense<2> : tensor<1xi32> // CHECK: %4 = "tfl.reduce_min"(%arg0, %cst) {keep_dims = false} : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> // CHECK: %5 = "tfl.arg_min"(%arg0, %cst) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> @@ -418,7 +418,7 @@ func.func @convert_argmin(%arg0: 
tensor<4x32x256xf32>) -> (tensor<4x32xf32>, ten // CHECK-LABEL: func @convert_argmin_i16 func.func @convert_argmin_i16(%arg0: tensor<2xi16>) -> (tensor, tensor) { %0 = mhlo.constant dense : tensor - %1 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi32> + %1 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> %2 = mhlo.constant dense<32767> : tensor %3 = mhlo.constant dense<0> : tensor %4:2 = "mhlo.reduce"(%arg0, %1, %2, %3) ({ @@ -436,7 +436,7 @@ func.func @convert_argmin_i16(%arg0: tensor<2xi16>) -> (tensor, tensor func.return %4#0, %4#1 : tensor, tensor // CHECK: %0 = mhlo.constant dense : tensor - // CHECK: %1 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi32> + // CHECK: %1 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> // CHECK-DAG: %2 = mhlo.constant dense<32767> : tensor // CHECK: %3 = mhlo.constant dense<0> : tensor // CHECK: %cst = arith.constant dense<0> : tensor<1xi32> @@ -477,7 +477,7 @@ func.func @convert_argmin_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32> // CHECK-LABEL: func.func @convert_argmin_bool func.func @convert_argmin_bool(%arg0: tensor<2xi1>) -> tensor { - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> %1 = mhlo.constant dense : tensor %2 = mhlo.constant dense<0> : tensor %3:2 = mhlo.reduce(%arg0 init: %1), (%0 init: %2) across dimensions = [0] : (tensor<2xi1>, tensor<2xi32>, tensor, tensor) -> (tensor, tensor) @@ -493,7 +493,7 @@ func.func @convert_argmin_bool(%arg0: tensor<2xi1>) -> tensor { } return %3#1 : tensor - // CHECK: %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi32> + // CHECK: %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> // CHECK-DAG: %1 = mhlo.constant dense : tensor // CHECK: %2 = mhlo.constant dense<0> : tensor // CHECK: %cst = arith.constant dense<0> : tensor<1xi32> @@ -506,7 +506,7 @@ func.func 
@convert_argmin_bool(%arg0: tensor<2xi1>) -> tensor { func.func @convert_argmax_with_reshaped_iota(%arg0: tensor<1x32x1xf32>) -> (tensor<1x1xf32>, tensor<1x1xi32>) { %0 = mhlo.constant dense<0xFF800000> : tensor %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<32xi32> + %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32> %3 = "mhlo.reshape"(%2) : (tensor<32xi32>) -> tensor<1x32x1xi32> %4:2 = mhlo.reduce(%arg0 init: %0), (%3 init: %1) across dimensions = [1] : (tensor<1x32x1xf32>, tensor<1x32x1xi32>, tensor, tensor) -> (tensor<1x1xf32>, tensor<1x1xi32>) reducer(%arg1: tensor, %arg3: tensor) (%arg2: tensor, %arg4: tensor) { @@ -525,7 +525,7 @@ func.func @convert_argmax_with_reshaped_iota(%arg0: tensor<1x32x1xf32>) -> (tens // CHECK-DAG: %0 = mhlo.constant dense<0xFF800000> : tensor // CHECK: %1 = mhlo.constant dense<0> : tensor - // CHECK: %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<32xi32> + // CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32> // CHECK: %3 = mhlo.reshape %2 : (tensor<32xi32>) -> tensor<1x32x1xi32> // CHECK: %cst = arith.constant dense<1> : tensor<1xi32> // CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32> @@ -537,7 +537,7 @@ func.func @convert_argmax_with_reshaped_iota(%arg0: tensor<1x32x1xf32>) -> (tens func.func @convert_pytorch_argmax(%arg0: tensor<1x9xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<-2147483648> : tensor - %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<9xi32> + %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<9xi32> %3 = mhlo.reshape %2 : (tensor<9xi32>) -> tensor<1x9xi32> %4:2 = mhlo.reduce(%arg0 init: %1), (%3 init: %0) across dimensions = [1] : (tensor<1x9xi32>, tensor<1x9xi32>, tensor, tensor) -> (tensor<1xi32>, tensor<1xi32>) reducer(%arg1: tensor, %arg3: tensor) (%arg2: tensor, 
%arg4: tensor) { @@ -553,7 +553,7 @@ func.func @convert_pytorch_argmax(%arg0: tensor<1x9xi32>) -> tensor<1xi32> { // CHECK: %0 = mhlo.constant dense<0> : tensor // CHECK-DAG: %1 = mhlo.constant dense<-2147483648> : tensor - // CHECK: %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<9xi32> + // CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<9xi32> // CHECK: %3 = mhlo.reshape %2 : (tensor<9xi32>) -> tensor<1x9xi32> // CHECK: %cst = arith.constant dense<1> : tensor<1xi32> // CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/unfold_splat_constant_pass.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfold_splat_constant_pass.mlir index fbad58fca6e940..fab612eef01a4f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/unfold_splat_constant_pass.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfold_splat_constant_pass.mlir @@ -6,7 +6,7 @@ func.func @unfold_splat_constant_float() -> tensor<1x750xf32> { func.return %cst : tensor<1x750xf32> // CHECK-DAG: %0 = mhlo.constant dense<7.680000e+02> : tensor - // CHECK: %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x750xf32> + // CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<1x750xf32> // CHECK: return %1 : tensor<1x750xf32> } @@ -16,7 +16,7 @@ func.func @unfold_splat_constant_integer() -> tensor<1x750xi32> { func.return %cst : tensor<1x750xi32> // CHECK-DAG: %0 = mhlo.constant dense<1> : tensor - // CHECK: %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x750xi32> + // CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<1x750xi32> // CHECK: return %1 : tensor<1x750xi32> } diff --git 
a/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir index 6f2771756cbac5..70a196f2af44c9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir @@ -16,9 +16,9 @@ func.func @batchNormInference_2D_inner_features( // CHECK-DAG: %[[MULTIPLIER:.+]] = mhlo.multiply %[[VARIANCE_EPS_RSQRT]], %[[SCALE]] : tensor<256xf32> // CHECK-DAG: %[[MUL_MEAN:.+]] = mhlo.multiply %[[MULTIPLIER]], %[[MEAN]] : tensor<256xf32> // CHECK-DAG: %[[RHS:.+]] = mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor<256xf32> - // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[RHS_BCAST]] : tensor<4x256xf32> %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : @@ -44,8 +44,8 @@ func.func @batchNormInference_4D_middle_features( // CHECK-DAG: %[[MULTIPLIER:.+]] = mhlo.multiply %[[VARIANCE_EPS_RSQRT]], %[[SCALE]] : tensor<256xf32> // CHECK-DAG: %[[MUL_MEAN:.+]] = mhlo.multiply %[[MULTIPLIER]], %[[MEAN]] : tensor<256xf32> // CHECK-DAG: %[[RHS:.+]] = 
mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor<256xf32> - // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> - // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<3x4x256x6xf32> %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} : (tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, @@ -66,16 +66,16 @@ func.func @batchNormInference_dynamic_shape( -> tensor { // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor // CHECK-DAG: %[[VAR_SHAPE:.+]] = shape.shape_of %[[VARIANCE]] : tensor -> tensor<1xindex> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[VAR_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[VAR_SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[R_STDDEV:.+]] = mhlo.rsqrt %[[VARIANCE_EPS]] : tensor // CHECK-DAG: %[[MULTIPLIER:.+]] = mhlo.multiply %[[R_STDDEV]], %[[SCALE]] : tensor // CHECK-DAG: %[[MUL_MEAN:.+]] = mhlo.multiply %[[MULTIPLIER]], %[[MEAN]] : tensor // CHECK-DAG: %[[RHS:.+]] = mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor // CHECK-DAG: %[[X_SHAPE:.+]] = 
shape.shape_of %[[X]] : tensor -> tensor<4xindex> - // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MULTIPLIER]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MULTIPLIER]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor - // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[RHS]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[RHS]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[RHS_BCAST]] : tensor %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) {epsilon = 0.001 : f32, feature_index = 1 : i64} : @@ -136,7 +136,7 @@ func.func @batchNormTraining_4D_middle_features( // CHECK-DAG: %[[X_SHAPE:.+]] = shape.shape_of %[[X]] : tensor<3x4x256x6xf32> -> tensor<4xindex> // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf32> // CHECK-DAG: %[[MEAN:.+]] = "tf.Mean"(%arg0, %[[CST_AXIS]]) <{keep_dims = false}> : (tensor<3x4x256x6xf32>, tensor<3xi32>) -> tensor<256xf32> - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>, tensor<4xindex>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>, tensor<4xindex>) -> tensor<3x4x256x6xf32> // CHECK-DAG: %[[SQ_DIFF:.+]] = "tf.SquaredDifference"(%arg0, %[[MEAN_BCAST]]) : (tensor<3x4x256x6xf32>, 
tensor<3x4x256x6xf32>) -> tensor<3x4x256x6xf32> // CHECK-DAG: %[[VARIANCE:.+]] = "tf.Mean"(%[[SQ_DIFF]], %[[CST_AXIS]]) <{keep_dims = false}> : (tensor<3x4x256x6xf32>, tensor<3xi32>) -> tensor<256xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS]] : tensor<256xf32> @@ -144,9 +144,9 @@ func.func @batchNormTraining_4D_middle_features( // CHECK-DAG: %[[MULTIPLIER:.+]] = mhlo.multiply %[[VARIANCE_EPS_RSQRT]], %[[SCALE]] : tensor<256xf32> // CHECK-DAG: %[[MUL_MEAN:.+]] = mhlo.multiply %[[MULTIPLIER]], %[[MEAN]] : tensor<256xf32> // CHECK-DAG: %[[RHS:.+]] = mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor<256xf32> - // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<3x4x256x6xf32> // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor<3x4x256x6xf32> - // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<3x4x256x6xf32> // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[RHS_BCAST]] : tensor<3x4x256x6xf32> %0:3 = "mhlo.batch_norm_training"(%x, %scale, %offset) {epsilon = 1.0 : f32, feature_index = 2 : i64} : diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc new file mode 100644 index 00000000000000..801c8775682cbd --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc @@ -0,0 +1,154 @@ +/* Copyright 2024 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/util/padding.h" + +namespace mlir { +namespace odml { + +DenseIntElementsAttr GetPaddingArrayAttr(Builder& builder, Operation* old_op) { + mhlo::CompositeOp composite_op = llvm::dyn_cast(old_op); + auto composite_attrs = composite_op.getCompositeAttributes(); + std::vector padding_vec; + GetI32VectorFromDenseI64CompositeAttr(composite_attrs, "padding", + &padding_vec); + + std::vector result_padding_conf(8, 0); // NHWC + 
result_padding_conf[2] = result_padding_conf[3] = padding_vec[0]; + result_padding_conf[4] = result_padding_conf[5] = padding_vec[1]; + + return DenseIntElementsAttr::get( + RankedTensorType::get({4, 2}, builder.getI32Type()), result_padding_conf); +} + +ShapedType GetPaddedType(Operation* old_op) { + auto input_type = old_op->getOperand(0).getType().cast(); + auto input_shape = input_type.getShape(); // NCHW + int64_t batch_size = input_shape[0]; + int64_t channel_size = input_shape[1]; + int64_t height = input_shape[2]; + int64_t width = input_shape[3]; + + DenseIntElementsAttr padding_attr; + mhlo::CompositeOp composite_op = llvm::dyn_cast(old_op); + auto composite_attributes = composite_op.getCompositeAttributes(); + EnsureAttribute(composite_attributes, "padding", + &padding_attr); + std::vector padding_values(padding_attr.getValues().begin(), + padding_attr.getValues().end()); + int64_t padding_height = padding_values[0]; + int64_t padding_width = padding_values[1]; + + std::array output_shape = { + batch_size, height + 2 * padding_height, width + 2 * padding_width, + channel_size}; // NHWC + return RankedTensorType::get(output_shape, input_type.getElementType()); +} + +// Checks if the provided configuration can be supported by the tensorflow +// "SAME" padding configuration. 
+static bool IsSamePadding(const std::vector& spatial_dim_sizes, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& padding_array) { + for (int dim : llvm::seq(0, spatial_dim_sizes.size())) { + int64_t discard; + int64_t pad_low_ignore; + int64_t pad_high_ignore; + absl::Status status = tensorflow::GetWindowedOutputSizeVerbose( + spatial_dim_sizes[dim], kernel_size[dim], 1, strides[dim], + tensorflow::Padding::SAME, &discard, &pad_low_ignore, &pad_high_ignore); + if (!status.ok()) { + return false; + } + if (padding_array[dim] != pad_low_ignore || + padding_array[dim] != pad_high_ignore) { + return false; + } + } + + return true; +} + +enum class PaddingType { kValid, kSame, kCustom }; + +static PaddingType GetPaddingType(const std::vector& spatial_dim_sizes, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& padding_array) { + if (std::all_of(padding_array.begin(), padding_array.end(), + [](int32_t padding_value) { return padding_value == 0; })) { + return PaddingType::kValid; + } + if (IsSamePadding(spatial_dim_sizes, kernel_size, strides, padding_array)) { + return PaddingType::kSame; + } + return PaddingType::kCustom; +} + +StringAttr GetPaddingStringAttr(Builder& builder, Operation* old_op) { + mhlo::CompositeOp composite_op = llvm::dyn_cast(old_op); + auto composite_attrs = composite_op.getCompositeAttributes(); + + auto operand_shape = + composite_op.getOperand(0).getType().cast().getShape(); + // NC(H)(W) + std::vector spatial_dim_sizes = { + static_cast(operand_shape[2]), + static_cast(operand_shape[3])}; + + std::vector padding_vec, kernel_size_vec, strides_vec; + GetI32VectorFromDenseI64CompositeAttr(composite_attrs, "kernel_size", + &kernel_size_vec); + GetI32VectorFromDenseI64CompositeAttr(composite_attrs, "stride", + &strides_vec); + GetI32VectorFromDenseI64CompositeAttr(composite_attrs, "padding", + &padding_vec); + PaddingType padding_type = GetPaddingType(spatial_dim_sizes, 
kernel_size_vec, + strides_vec, padding_vec); + + switch (padding_type) { + case PaddingType::kValid: + return builder.getStringAttr("VALID"); + case PaddingType::kSame: + return builder.getStringAttr("SAME"); + case PaddingType::kCustom: + return builder.getStringAttr("CUSTOM"); + } +} + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.h new file mode 100644 index 00000000000000..4224f2d6c8ae10 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_AVG_POOL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_AVG_POOL_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { + +// Given a Composite op that wraps a core.aten.avg_pool2d, returns the padding +// configuration required for the `tfl.pad` if the padding part of the op is +// to be done before average pooling. +DenseIntElementsAttr GetPaddingArrayAttr(Builder& builder, Operation* old_op); + +// Given a Composite op that wraps a core.aten.avg_pool2d, and assuming that +// the padding part is extracted into a tfl.pad op prior to a +// tfl.average_pool_2d, this function finds the return type of the needed +// tfl.pad . +ShapedType GetPaddedType(Operation* old_op); + +// Given a Composite op that wraps a core.aten.avg_pool2d, finds the padding +// attribute to be passed to the a tfl.average_pool_2d that can fully replace +// this composite (here, padding is done directly by the tfl.average_pool_2d as +// opposed to being extracted into a separate tfl.pad). 
+StringAttr GetPaddingStringAttr(Builder& builder, Operation* old_op); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_AVG_POOL_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool_patterns.td new file mode 100644 index 00000000000000..607b8f520ba6f9 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool_patterns.td @@ -0,0 +1,91 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/PatternBase.td" +include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td" + + +// See the function doc in the header file. +def GetPaddedType : NativeCodeCall< + "GetPaddedType((*$0.begin()).getDefiningOp())">; + +// See the function doc in the header file. +def GetPadding: + NativeCodeCall<"GetPaddingStringAttr($_builder, (*$0.begin()).getDefiningOp())">; + +// Returns true if the provided padding in the composite op can *not* be +// satisfied by SAME or VALID tensorflow padding. +def HasCustomPadding: + Constraint>; + +// Returns true if the provided padding in the composite op can be satisfied +// by SAME or VALID tensorflow padding. 
+def HasSameOrValidPadding: Constraint>; + +// See the function doc in the header file. +def GetPaddingArrayAttr: NativeCodeCall<"GetPaddingArrayAttr($_builder, (*$0.begin()).getDefiningOp())">; + +// Replaces an ate.avg_pool2d with a (T -> tfl.average_pool_2d -> T). +// Constraints are added on the attributes of the aten.avg_pool2d to ensure only +// ops that match the behaviour of tfl.average_pool_2d are directly lowered. +def LegalizeAvgPool2dComposite: Pat< + (MHLO_CompositeOp:$old_val + (variadic $a_input), + ConstantStrAttr, $attrs, $_, $_), + (TFL_TransposeOp + (TFL_AveragePool2DOp + /*input*/ (TFL_TransposeOp $a_input, + (Arith_ConstantOp + ConstantAttr,"{0, 2, 3, 1}">)), + /*filter_height*/(GetI32At<0> (GetAsVectorAttr<"kernel_size"> $attrs)), + /*filter_width*/(GetI32At<1> (GetAsVectorAttr<"kernel_size"> $attrs)), + /*padding*/(GetPadding $old_val), + /*stride_h*/(GetI32At<0> (GetAsVectorAttr<"stride"> $attrs)), + /*stride_w*/(GetI32At<1> (GetAsVectorAttr<"stride"> $attrs)), + /*fused_activation_function*/TFL_AF_None, + (returnType (GetNhwcReturnTypeFromNchw $old_val))), + (Arith_ConstantOp + ConstantAttr,"{0, 3, 1, 2}">)), + [(IsBoolCompositeAttribute<"ceil_mode", "false"> $attrs), + (IsBoolCompositeAttribute<"count_include_pad", "false"> $attrs), + (IsStrCompositeAttribute<"divisor_override", "py_None"> $attrs), + (HasSameOrValidPadding $old_val)]>; + +// Replaces an ate.avg_pool2d with (T -> tfl.pad -> tfl.average_pool_2d -> T). 
+def LegalizeAvgPool2dWithPadComposite: Pat< + (MHLO_CompositeOp:$old_val + (variadic $a_input), + ConstantStrAttr, $attrs, $_, $_), + (TFL_TransposeOp + (TFL_AveragePool2DOp:$padded_value + /*input*/ (TFL_PadOp + (TFL_TransposeOp $a_input, + (Arith_ConstantOp + ConstantAttr,"{0, 2, 3, 1}">)), + (Arith_ConstantOp + (GetPaddingArrayAttr $old_val)), + (returnType (GetPaddedType $old_val))), + /*filter_height*/(GetI32At<0> (GetAsVectorAttr<"kernel_size"> $attrs)), + /*filter_width*/(GetI32At<1> (GetAsVectorAttr<"kernel_size"> $attrs)), + /*padding*/TFL_PAD_Valid, + /*stride_h*/(GetI32At<0> (GetAsVectorAttr<"stride"> $attrs)), + /*stride_w*/(GetI32At<1> (GetAsVectorAttr<"stride"> $attrs)), + /*fused_activation_function*/TFL_AF_None, + (returnType (GetNhwcReturnTypeFromNchw $old_val))), + (Arith_ConstantOp + ConstantAttr,"{0, 3, 1, 2}">)), + [(IsBoolCompositeAttribute<"ceil_mode", "false"> $attrs), + (IsStrCompositeAttribute<"divisor_override", "py_None"> $attrs), + (IsBoolCompositeAttribute<"count_include_pad", "true"> $attrs)]>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc index 0dc354f998d246..11e2272a145f0b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -25,6 +26,8 @@ limitations under the License. 
#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep @@ -57,6 +60,7 @@ void CompositeLoweringPass::runOnOperation() { ConversionTarget target(context); target.addLegalDialect(); + target.addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index 1b62b6fcc4aeae..829cf2fbaf16a4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -14,15 +14,38 @@ limitations under the License. ==============================================================================*/ // Pattern definition file for direct lowering of mhlo composites to tflite ops. 
- include "mlir/IR/OpBase.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mhlo/IR/hlo_ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" - +include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td" +include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool_patterns.td" def LegalizeHardSwishComposite: Pat< - (MHLO_CompositeOp:$old_value + (MHLO_CompositeOp:$old_val (variadic $input), ConstantStrAttr, $_, $_, $_), (TFL_HardSwishOp $input)>; + +// Checks if the given op is an InterpolateBilinear op with NCHW layout. +// Supplied arguments are the input, output op values and the output shape. +def IsSupportedNchwUpsampleBlinear: Constraint())">>; + +def LegalizeTorchUpsampleBlinear2dComposite: Pat< + (MHLO_CompositeOp:$old_val + (variadic $input), + ConstantStrAttr, $attrs, $_, $_), + (TFL_TransposeOp + (TFL_ResizeBilinearOp + (TFL_TransposeOp $input, + (Arith_ConstantOp + ConstantAttr,"{0, 2, 3, 1}">)), + (Arith_ConstantOp:$output_size (GetI32DenseAttr (GetAsVectorAttr<"output"> $attrs))), + (GetCompositeAttributeAs<"align_corners", "BoolAttr"> $attrs), + ConstBoolAttrTrue, + (returnType (GetNhwcReturnTypeFromNchw $old_val))), + (Arith_ConstantOp + ConstantAttr,"{0, 3, 1, 2}">)), + [(IsSupportedNchwUpsampleBlinear $input, $old_val, $attrs)]>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc new file mode 100644 index 00000000000000..403bf9968a9acd --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc @@ -0,0 +1,105 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project + +namespace mlir { +namespace odml { + +DenseIntElementsAttr DenseI64AttrToI32Attr( + const DenseIntElementsAttr& dense_attr, PatternRewriter& builder) { + std::vector ret(dense_attr.getNumElements()); + auto range = dense_attr.getValues(); + std::transform(range.begin(), range.end(), ret.begin(), + [](int64_t attr) { return static_cast(attr); }); + return DenseIntElementsAttr::get( + RankedTensorType::get(ret.size(), builder.getIntegerType(32)), ret); +} + +bool DenseI64AttrToI32Vector(const DenseIntElementsAttr& dense_attr, + std::vector* out_vec) { + std::vector ret(dense_attr.getNumElements()); + auto range = dense_attr.getValues(); + std::transform(range.begin(), range.end(), ret.begin(), + [](int64_t attr) { return static_cast(attr); }); + *out_vec = std::move(ret); + return true; +} + +bool GetI32VectorFromDenseI64CompositeAttr( + const DictionaryAttr& composite_attrs, const std::string& attr_name, + std::vector* out_vec) { + DenseIntElementsAttr attr; + if (!EnsureAttribute(composite_attrs, attr_name, + &attr)) { + return false; + } + + return 
DenseI64AttrToI32Vector(attr, out_vec); +} + +bool IsSupportedNchwUpsampleBlinear( + Value input, Value output, const DenseIntElementsAttr& output_size_attr) { + auto input_shape = input.getType().cast().getShape(); + auto output_shape = output.getType().cast().getShape(); + + // Only support 4D tensor. + if (input_shape.size() != 4 || output_shape.size() != 4) { + return false; + } + + // Only expects the first two dimensions of input and output to be the same as + // in NCHW. + if (input_shape[0] != output_shape[0] || input_shape[1] != output_shape[1]) { + return false; + } + + // Supplied output size should be 2D. + if (output_size_attr.getNumElements() != 2) { + return false; + } + auto output_size = output_size_attr.getValues(); + return output_size[0] == output_shape[2] && output_size[1] == output_shape[3]; +} + +ShapedType GetNhwcReturnTypeFromNchw(Operation* old_op) { + auto composite_result_shape = + old_op->getResults().front().getType().cast().getShape(); + std::array output_shape; + // NHWC <- NCHW + output_shape[0] = composite_result_shape[0]; + output_shape[1] = composite_result_shape[2]; + output_shape[2] = composite_result_shape[3]; + output_shape[3] = composite_result_shape[1]; + + auto input_type = old_op->getOperand(0).getType().cast(); + + return RankedTensorType::get(output_shape, input_type.getElementType()); +} +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h new file mode 100644 index 00000000000000..0691dc74997212 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_UTILS_H_ + +#include +#include +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { + +// Ensure an attribute named attr_name exists and it is of type AttrType. +// If so, sets the `out_attr` pointer to point to the casted attribute. +template +bool EnsureAttribute(const DictionaryAttr& composite_attributes, + const std::string& attr_name, AttrType* out_attr) { + Attribute attr = composite_attributes.get(attr_name); + if (!attr.isa_and_nonnull()) { + return false; + } + if (AttrType content = attr.dyn_cast()) { + *out_attr = content; + return true; + } else { + return false; + } +} + +// Changes a DenseIntElementsAttr **containing I64** elements to an I32 Vector. 
+bool DenseI64AttrToI32Vector(const DenseIntElementsAttr& dense_attr, + std::vector* out_vec); + +// Given a DictionaryAttr, checks if it has a DenseIntElementsAttr attribute +// with the name attr_name. If so, extracts its values and stores as a vector +// of int32_t elements. +// Note: This assumes the DenseIntElementsAttr has its values stored as int64_t. +bool GetI32VectorFromDenseI64CompositeAttr( + const DictionaryAttr& composite_attrs, const std::string& attr_name, + std::vector* out_vec); + +// Get a DenseIntElementsAttr of type I64 and convert it to an I32 attribute. +DenseIntElementsAttr DenseI64AttrToI32Attr( + const DenseIntElementsAttr& dense_attr, PatternRewriter& builder); + +// Returns true if the given input and output are in NCHW layout +bool IsSupportedNchwUpsampleBlinear( + Value input, Value output, const DenseIntElementsAttr& output_size_attr); + +// Returns a NHWC shaped type from an NCHW shaped type op. +// For example- Given a Composite op that wraps a core.aten.avg_pool2d, this +// returns the return type of the tfl.average_pool_2d emitted. Note that the +// aten.avg_pool2d works with the NCHW layout while tfl.average_pool_2d assumes +// NHWC. +ShapedType GetNhwcReturnTypeFromNchw(Operation* old_op); + +} // namespace odml + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td new file mode 100644 index 00000000000000..d39a8efb8b13b3 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td @@ -0,0 +1,64 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef COMPOSITE_UTILS_TD +#define COMPOSITE_UTILS_TD + +include "mlir/IR/PatternBase.td" + +// See the function doc in the header file. +def GetNhwcReturnTypeFromNchw: NativeCodeCall< + "GetNhwcReturnTypeFromNchw((*$0.begin()).getDefiningOp())">; + + +// When given a DenseIntElementsAttr containing I64 elements, this extracts +// one I32IntegerAttr from the given index. +class GetI32At: NativeCodeCall< + "$_builder.getI32IntegerAttr(static_cast(*($0.getValues().begin() + " # index #")))">; + +def GetI32DenseAttr: NativeCodeCall< + "DenseI64AttrToI32Attr($0, $_builder)">; + +// Receives a composite DictionaryAttr and returns the value of the Attribute +// with the key `attr_name` as the type provided by `attr_type`. +class GetCompositeAttributeAs: + NativeCodeCall<"$0.get(\"" # attr_name # "\").dyn_cast<" # attr_type # ">()">; + +// Receives a composite DictionaryAttr and returns the value of the Attribute +// with the key `attr_name` as a DenseIntElementsAttr. +class GetAsVectorAttr: + GetCompositeAttributeAs; + +class IsBoolAttrEqual : Constraint>; + +// Receives a composite DictionaryAttr as an argument and checks if one of the +// its attributes (with the name `attr_name`) is of type `attribute` and has +// the value `val`. +class IsCompositeAttribute: + Constraint>; + +// Receives a composite DictionaryAttr as an argument and checks if has a +// BoolAttr with the name `attr_name` and value `val`. 
+class IsBoolCompositeAttribute : + IsCompositeAttribute; + +// Receives a composite DictionaryAttr as an argument and checks if has a +// StrAttr with the name `attr_name` and value `val`. +class IsStrCompositeAttribute : + IsCompositeAttribute; + +#endif diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc index a35f5ba324e3f4..6e0a3325460b7a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -47,6 +48,10 @@ bool IsSupportedComposite(::mlir::stablehlo::CompositeOp op) { op.getName()); } +bool IsKVCacheCompositeOp(::mlir::stablehlo::CompositeOp op) { + return op.getName() == "odml.update_kv_cache"; +} + TFL::ConstBytesAttr CustomOption(OpBuilder* builder, const std::string& content) { return TFL::ConstBytesAttr::get(builder->getContext(), @@ -75,6 +80,12 @@ TFL::CustomOp BuildCustomOp(stablehlo::CompositeOp composite, const std::string& custom_option_buffer) { OpBuilder builder(composite->getContext()); builder.setInsertionPoint(composite); + if (IsKVCacheCompositeOp(composite)) { + return builder.create( + composite->getLoc(), composite->getResultTypes(), + composite->getOperands().slice(2, 3), composite.getName(), + CustomOption(&builder, custom_option_buffer)); + } return builder.create( composite->getLoc(), 
composite->getResultTypes(), composite->getOperands(), composite.getName(), @@ -104,11 +115,48 @@ struct LegalizeCompositeToCustomOpPass void runOnOperation() override { func::FuncOp fn = getOperation(); + + int num_layers = 0, current_layer_index = 0; + // First walk the function to count number of KV Caches. + fn.walk([&](Operation* op) { + auto composite = llvm::dyn_cast(op); + if (!composite || !IsKVCacheCompositeOp(composite)) return; + num_layers++; + }); + fn.walk([&](Operation* op) { // Process only StableHLO composite ops. auto composite = llvm::dyn_cast(op); if (!composite || !IsSupportedComposite(composite)) return; + if (IsKVCacheCompositeOp(composite)) { + auto comp_attr = composite.getCompositeAttributes(); + mlir::Builder builder(composite->getContext()); + + // num_layers Composite Attribute. + mlir::StringAttr num_layers_str = builder.getStringAttr("num_layers"); + NamedAttribute num_layers_attr( + num_layers_str, + IntegerAttr::get(IntegerType::get(fn.getContext(), /*width=*/32), + num_layers)); + + // current_layer_index Composite Attribute. + mlir::StringAttr current_layer_str = + builder.getStringAttr("layer_index"); + NamedAttribute current_layer_attr( + current_layer_str, + IntegerAttr::get(IntegerType::get(fn.getContext(), /*width=*/32), + current_layer_index++)); + + // Build a new CompositeAttributes attr, add in the above, + // and set for the op. + mlir::NamedAttrList attributes(comp_attr); + attributes.append(num_layers_attr); + attributes.append(current_layer_attr); + comp_attr = attributes.getDictionary(builder.getContext()); + composite.setCompositeAttributesAttr(comp_attr); + } + // Build flexbuffer options. 
std::string custom_option_buffer; auto fbb = std::make_unique(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc index fcacfcf4984db1..a22c392163b09d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc @@ -52,17 +52,17 @@ class TFToMhloPass public: explicit TFToMhloPass(bool skip_quantization_ops = false, bool skip_resize = false, - bool skip_stateful_partitioned_call = false) + bool skip_partitioned_calls = false) : PassWrapper() { skip_quantization_ops_ = skip_quantization_ops; skip_resize_ = skip_resize; - skip_stateful_partitioned_call_ = skip_stateful_partitioned_call; + skip_partitioned_calls_ = skip_partitioned_calls; } TFToMhloPass(const TFToMhloPass &pass) { skip_quantization_ops_ = pass.skip_quantization_ops_; skip_resize_ = pass.skip_resize_; - skip_stateful_partitioned_call_ = pass.skip_stateful_partitioned_call_; + skip_partitioned_calls_ = pass.skip_partitioned_calls_; } private: @@ -90,9 +90,10 @@ class TFToMhloPass *this, "skip-resize", ::llvm::cl::desc("Skip tf.ResizeBilinear and tf.ResizeNearestNeighbor")}; - Option skip_stateful_partitioned_call_{ - *this, "skip-stateful-partitioned-call", - ::llvm::cl::desc("Skip tf.StatefulPartitionedCall")}; + Option skip_partitioned_calls_{ + *this, "skip-partitioned-calls", + ::llvm::cl::desc( + "Skip tf.StatefulPartitionedCall and tf.PartitionedCall")}; }; void TFToMhloPass::runOnOperation() { @@ -129,7 +130,8 @@ void TFToMhloPass::runOnOperation() { target.addLegalOp(); target.addLegalOp(); } - if (skip_stateful_partitioned_call_) { + if (skip_partitioned_calls_) { + target.addLegalOp(); target.addLegalOp(); } @@ -145,9 +147,10 @@ struct TFToStablehloOptions : public PassPipelineOptions { Option skip_resize{ *this, "skip-resize", ::llvm::cl::desc("Skip tf.ResizeBilinear and 
tf.ResizeNearestNeighbor")}; - Option skip_stateful_partitioned_call{ - *this, "skip-stateful-partitioned-call", - ::llvm::cl::desc("Skip tf.StatefulPartitionedCall")}; + Option skip_partitioned_calls{ + *this, "skip-partitioned-calls", + ::llvm::cl::desc( + "Skip tf.StatefulPartitionedCall and tf.PartitionedCall")}; }; void PopulateLegalizeTFToStablehloPipeline( @@ -157,7 +160,7 @@ void PopulateLegalizeTFToStablehloPipeline( // reusing their work, perhaps through `LowerToMlProgramAndHlo`. pm.addNestedPass(std::make_unique( options.skip_quantization_ops, options.skip_resize, - options.skip_stateful_partitioned_call)); + options.skip_partitioned_calls)); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mhlo::createHloLegalizeToStablehloPass()); } @@ -170,11 +173,11 @@ static PassPipelineRegistration void AddLegalizeTFToStablehloPasses(OpPassManager &pm, bool skip_quantization_ops, bool skip_resize, - bool skip_stateful_partitioned_call) { + bool skip_partitioned_calls) { TFToStablehloOptions options; options.skip_quantization_ops = skip_quantization_ops; options.skip_resize = skip_resize; - options.skip_stateful_partitioned_call = skip_stateful_partitioned_call; + options.skip_partitioned_calls = skip_partitioned_calls; PopulateLegalizeTFToStablehloPipeline(pm, options); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h index 91eafb5ab7fa49..c26a3f36daf675 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h @@ -25,7 +25,7 @@ namespace odml { void AddLegalizeTFToStablehloPasses(OpPassManager& pm, bool skip_quantization_ops, bool skip_resize, - bool skip_stateful_partitioned_call); + bool skip_partitioned_calls); } // namespace odml } // namespace mlir diff --git 
a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index 8fed8f3f01ed54..ad3bc3cd4cd24d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -61,8 +61,10 @@ using ::mlir::quant::CreateI32F32UniformQuantizedPerAxisType; using ::mlir::quant::CreateI32F32UniformQuantizedType; using ::mlir::quant::CreateI8F32UniformQuantizedPerAxisType; using ::mlir::quant::CreateI8F32UniformQuantizedType; +using ::mlir::quant::FindOperandOfType; using ::mlir::quant::FindUserOfType; using ::mlir::quant::GetElementType; +using ::mlir::quant::IsDotGeneralFullyConnected; using ::mlir::quant::IsI32F32UniformQuantizedPerAxisType; using ::mlir::quant::IsI32F32UniformQuantizedType; using ::mlir::quant::IsI8F32UniformQuantizedPerAxisType; @@ -107,6 +109,20 @@ double GetBiasScale(const double input_scale, const double filter_scale) { return filter_scale * input_scale; } +// Returns the optionally broadcasted bias constant op used for a given op. +// If no such constant op exists, returns a nullptr. +Operation* GetBiasConstOp(Operation* op) { + Operation* bias_const_op; + if (Operation* broadcast_in_dim_op = + FindOperandOfType(op); + broadcast_in_dim_op != nullptr) { + bias_const_op = broadcast_in_dim_op->getOperand(0).getDefiningOp(); + } else { + bias_const_op = FindOperandOfType(op); + } + return isa(bias_const_op) ? bias_const_op : nullptr; +} + // Creates a new `tfl.qconst` op for the quantized filter. Transposes the // filter value from [i, o] -> [o, i]. This is because we assume `[i, o]` // format for `stablehlo.dot_general` (i.e. 
contracting dimension == 1) @@ -426,8 +442,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp LogicalResult match(stablehlo::DotGeneralOp op) const override { const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = op.getDotDimensionNumbers(); - const bool is_batch_matmul = - !dot_dimension_nums.getLhsBatchingDimensions().empty(); + const bool is_batch_matmul = !IsDotGeneralFullyConnected(op).value(); const Type elem_type = GetElementType(op.getResult()); const bool has_i32_output = IsI32F32UniformQuantizedType(elem_type) || IsI32F32UniformQuantizedPerAxisType(elem_type); @@ -464,8 +479,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp IsI32F32UniformQuantizedPerAxisType(output_type); const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = op.getDotDimensionNumbers(); - const bool is_batch_matmul = - !dot_dimension_nums.getLhsBatchingDimensions().empty(); + const bool is_batch_matmul = !IsDotGeneralFullyConnected(op).value(); if (is_batch_matmul) { RewriteDotGeneralToTflBatchMatmulOp(op, rewriter, dot_dimension_nums, @@ -793,15 +807,17 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp .cast() .getZeroPoints(), /*quantization_dimension=*/0); - Operation* stablehlo_bias_op = add_op->getOperand(1).getDefiningOp(); - const auto bias_type = RankedTensorType::getChecked( - op->getLoc(), bias_shape, bias_quantized_type); - const auto bias_value = cast( - cast(stablehlo_bias_op).getValue()); - - *bias_tfl_op = rewriter.create( - op->getLoc(), - /*output=*/TypeAttr::get(bias_type), /*value=*/bias_value); + Operation* bias_const_op = GetBiasConstOp(add_op); + if (bias_const_op != nullptr) { + const auto bias_type = RankedTensorType::getChecked( + op->getLoc(), bias_shape, bias_quantized_type); + const auto bias_value = cast( + cast(bias_const_op).getValue()); + + *bias_tfl_op = rewriter.create( + op->getLoc(), + /*output=*/TypeAttr::get(bias_type), /*value=*/bias_value); + } } else { 
uniform_quantize_op = FindUserOfType(op); } @@ -902,22 +918,14 @@ class RewriteQuantizedConvolutionOp return failure(); } - // TODO: b/309896242 - Lift the assumptions on adjacent ops below - // as we cover more dynamic fused pattern legalization. if (fuse_bias_constant) { Operation* add_op = FindUserOfType(op); if (add_op == nullptr) { LLVM_DEBUG(llvm::dbgs() << "Failed to find AddOp for bias fusion.\n"); return failure(); } - Operation* broadcast_in_dim_op = add_op->getOperand(1).getDefiningOp(); - if (!isa(broadcast_in_dim_op)) { - LLVM_DEBUG(llvm::dbgs() << "Failed to find broadcasted bias.\n"); - return failure(); - } - Operation* bias_const_op = - broadcast_in_dim_op->getOperand(0).getDefiningOp(); - if (!isa(bias_const_op)) { + Operation* bias_const_op = GetBiasConstOp(add_op); + if (bias_const_op == nullptr) { LLVM_DEBUG(llvm::dbgs() << "Failed to find bias constant.\n"); return failure(); } @@ -1413,11 +1421,7 @@ class RewriteQuantizedConvolutionOp TFL::QConstOp bias; if (fuse_bias_constant && has_i32_output) { Operation* add_op = FindUserOfType(op); - // TODO: b/309896242 - Lift the assumptions on adjacent ops below - // as we cover more dynamic fused pattern legalization. 
- Operation* broadcast_in_dim_op = add_op->getOperand(1).getDefiningOp(); - Operation* bias_const_op = - broadcast_in_dim_op->getOperand(0).getDefiningOp(); + Operation* bias_const_op = GetBiasConstOp(add_op); const ElementsAttr bias_constant_value = cast(bias_const_op).getValue(); bias = rewriter.create(op.getLoc(), diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index ba9c1e58565286..f244d15294c253 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -181,6 +181,46 @@ func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, te func.return %7, %8, %9, %10, %11, %12, %13 : tensor, tensor, tensor, tensor, tensor, tensor, tensor } +// CHECK-LABEL: @max_with_neg_f32_max_val +// CHECK-SAME: (%[[ARG0:.+]]: tensor) +func.func @max_with_neg_f32_max_val(%arg0 : tensor) -> (tensor, tensor) { + %neg_f32_max = arith.constant dense<-3.40282347E+38> : tensor + %0 = "tfl.maximum"(%arg0, %neg_f32_max) : (tensor, tensor) -> tensor + %1 = "tfl.maximum"(%neg_f32_max, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + +// CHECK-LABEL: @min_with_f32_max_val +// CHECK-SAME: (%[[ARG0:.+]]: tensor) +func.func @min_with_f32_max_val(%arg0 : tensor) -> (tensor, tensor) { + %f32_max = arith.constant dense<3.40282347E+38> : tensor + %0 = "tfl.minimum"(%arg0, %f32_max) : (tensor, tensor) -> tensor + %1 = "tfl.minimum"(%f32_max, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + +// CHECK-LABEL: @max_with_neg_f64_max_val +// CHECK-SAME: (%[[ARG0:.+]]: tensor) +func.func @max_with_neg_f64_max_val(%arg0 : tensor) -> (tensor, tensor) { + %neg_f64_max = arith.constant dense<-1.7976931348623157E+308> : tensor + %0 = "tfl.maximum"(%arg0, %neg_f64_max) : (tensor, tensor) -> tensor + %1 = "tfl.maximum"(%neg_f64_max, 
%arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + +// CHECK-LABEL: @min_with_f64_max_val +// CHECK-SAME: (%[[ARG0:.+]]: tensor) +func.func @min_with_f64_max_val(%arg0 : tensor) -> (tensor, tensor) { + %f64_max = arith.constant dense<1.7976931348623157E+308> : tensor + %0 = "tfl.minimum"(%arg0, %f64_max) : (tensor, tensor) -> tensor + %1 = "tfl.minimum"(%f64_max, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + // CHECK-LABEL: @mul_int func.func @mul_int() -> (tensor, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { %0 = arith.constant dense<8> : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index aeb14ece5e26b7..3c83883dfea9a4 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -881,6 +881,25 @@ func.func @QuantizedCatsAddRequantsTest(%arg0: tensor<1x1xf32>, %arg1: tensor<1x // CHECK-NEXT: return %[[dqcat_2_0_1_0]], %[[dqcat_2_0_3]] : tensor<1x4xf32>, tensor<1x3xf32> } +// QDQ-LABEL: TransposePerTensorQuantizationPropagation +func.func @TransposePerTensorQuantizationPropagation() -> tensor<2x5xf32> { + %perm = arith.constant dense<[1, 0]> : tensor<2xi32> + %cst = arith.constant dense<1.0> : tensor<5x2xf32> + %q = "tfl.quantize"(%cst) {qtype = tensor<5x2x!quant.uniform:f32, 1.113490e-03>>} : (tensor<5x2xf32>) -> tensor<5x2x!quant.uniform:f32, 1.113490e-03>> + %dq = "tfl.dequantize"(%q) : (tensor<5x2x!quant.uniform:f32, 1.113490e-03>>) -> tensor<5x2xf32> + %t = "tfl.transpose"(%dq, %perm) : (tensor<5x2xf32>, tensor<2xi32>) -> tensor<2x5xf32> + func.return %t : tensor<2x5xf32> + + // QDQ: %[[perm:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> + // QDQ-NEXT: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<5x2xf32> + // QDQ-NEXT: %[[qw:.*]] = 
"tfl.quantize"(%[[w]]) {qtype = tensor<5x2x!quant.uniform:f32 + // QDQ-NEXT: %[[dqw:.*]] = "tfl.dequantize"(%[[qw]]) : (tensor<5x2x!quant.uniform:f32 + // QDQ-NEXT: %[[tp:.*]] = "tfl.transpose"(%[[dqw]], %[[perm]]) : (tensor<5x2xf32>, tensor<2xi32>) -> tensor<2x5xf32> + // QDQ-NEXT: %[[qtw:.*]] = "tfl.quantize"(%[[tp]]) {qtype = tensor<2x5x!quant.uniform:f32 + // QDQ-NEXT: %[[dqtw:.*]] = "tfl.dequantize"(%[[qtw]]) : (tensor<2x5x!quant.uniform:f32 + // QDQ-NEXT: return %[[dqtw]] : tensor<2x5xf32> +} + // QDQ-LABEL: TransposePerChannelNewQuantDim func.func @TransposePerChannelNewQuantDim() -> tensor<2x5xf32> { %perm = arith.constant dense<[1, 0]> : tensor<2xi32> diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 401f34e6e7943c..38a8bffd87bb03 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -1095,9 +1095,9 @@ void LegalizeTFPass::runOnOperation() { addPatterns(context, stage1Patterns, this->preserve_assert_op_); FrozenRewritePatternSet stage1FrozenPatterns(std::move(stage1Patterns)); - if (!applyPatterns(func, target, stage1FrozenPatterns)) + if (!applyPatterns(func, target, stage1FrozenPatterns)) { return signalPassFailure(); - + } // Explict BroadcastTo addition for left-over broadcast-able ops. // The following pattern matchings should be done after the other legalization // rules in order not to add unnecessary BroadcastTo ops. 
@@ -1126,8 +1126,9 @@ void LegalizeTFPass::runOnOperation() { ApplyExplicitBroadcasting>(context); FrozenRewritePatternSet stage2FrozenPatterns(std::move(stage2Patterns)); - if (!applyPatterns(func, target, stage2FrozenPatterns)) + if (!applyPatterns(func, target, stage2FrozenPatterns)) { return signalPassFailure(); + } } } // namespace diff --git a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc index 2ed12a34059588..e212ce16ee6ccd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc @@ -219,7 +219,7 @@ class LiftFlexCustomOp : public OpRewritePattern { for (const auto& name_and_value : node_def.attr()) { const std::string& attr_name = name_and_value.first; const tensorflow::AttrValue& attr_value = name_and_value.second; - StatusOr mlir_attr = + absl::StatusOr mlir_attr = tensorflow::ConvertAttributeValue(attr_value, &builder); if (!mlir_attr.ok()) { return emitError(loc, mlir_attr.status().message()); diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index e09030ceb7515f..f2e659b9aea9ce 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -27,7 +27,7 @@ limitations under the License. 
namespace tflite { -using xla::StatusOr; +using absl::StatusOr; namespace errors = tensorflow::errors; diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD index 8091fe21ef56ff..da122b67993af7 100644 --- a/tensorflow/compiler/mlir/quantization/common/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -57,6 +57,7 @@ tf_cc_test( name = "lift_as_function_call_test", srcs = ["lift_as_function_call_test.cc"], deps = [ + ":attrs_and_constraints", ":func", ":lift_as_function_call", ":test_base", @@ -148,6 +149,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -164,12 +167,14 @@ tf_cc_test( ":func", ":test_base", "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:status_matchers", "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc index b1f579bc8e71b8..540eff26685968 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" @@ -119,9 +121,11 @@ bool IsHybridQuantizedOp(Operation* op) { !IsQuantizedTensorType(result_type); } -std::optional GetDotGeneralQuantizationDim( - DotGeneralOp dot_general_op) { - if (dot_general_op == nullptr) return std::nullopt; +absl::StatusOr IsDotGeneralFullyConnected(DotGeneralOp dot_general_op) { + if (dot_general_op == nullptr) + return absl::InvalidArgumentError( + "Given dot_general op cannot be null when checking " + "`IsDotGeneralBatchMatmul`."); const ::mlir::stablehlo::DotDimensionNumbersAttr dot_dimension_numbers = dot_general_op.getDotDimensionNumbers(); const ArrayRef lhs_contracting_dims = @@ -132,10 +136,8 @@ std::optional GetDotGeneralQuantizationDim( dot_general_op.getOperand(0).getType().dyn_cast().getRank(); const int64_t filter_rank = dot_general_op.getOperand(1).getType().dyn_cast().getRank(); - // To quantize rhs per-channel, we currently only consider the case where - // `stablehlo.dot_general` is legalizable to `tfl.fully_connected`. // The following conditions are such requirements: - // - rank(lhs) <= 2 + // - rank(lhs) is 1 or 2 // - rank(rhs) = 2 // - size(lhs_contracting_dimensions) = 1 // - size(rhs_contracting_dimensions) = 1 @@ -144,7 +146,8 @@ std::optional GetDotGeneralQuantizationDim( // - quantization_dimension(rhs) should not be in // `rhs_contracting_dimensions`. 
// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general - const bool has_proper_rank = input_rank <= 2 && filter_rank == 2; + const bool has_proper_rank = + (input_rank == 1 || input_rank == 2) && filter_rank == 2; const bool has_proper_contracting_dim = lhs_contracting_dims.size() == 1 && rhs_contracting_dims.size() == 1 && lhs_contracting_dims[0] == input_rank - 1; @@ -153,9 +156,20 @@ std::optional GetDotGeneralQuantizationDim( const bool has_proper_quantization_dimension = absl::c_find(rhs_contracting_dims, filter_rank) == rhs_contracting_dims.end(); + return has_proper_rank && has_proper_contracting_dim && is_not_batch_op && + has_proper_quantization_dimension; +} + +std::optional GetDotGeneralQuantizationDim( + DotGeneralOp dot_general_op) { + if (dot_general_op == nullptr) return std::nullopt; + const int64_t filter_rank = + dot_general_op.getOperand(1).getType().dyn_cast().getRank(); + + // To quantize rhs per-channel, we currently only consider the case where + // `stablehlo.dot_general` is legalizable to `tfl.fully_connected`. const bool is_per_axis_quantizable = - has_proper_rank && has_proper_contracting_dim && is_not_batch_op && - has_proper_quantization_dimension; + IsDotGeneralFullyConnected(dot_general_op).value(); if (!is_per_axis_quantizable) return std::nullopt; return filter_rank - 1; } diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index 852902e229a9fc..490a77a3b73ffa 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -178,6 +179,15 @@ FailureOr CastI64ToI32(int64_t value); FailureOr> CastI64ArrayToI32( ArrayRef int64_array); +// Returns the first operation with the given type in the function. +template +OpType FindOperationOfType(func::FuncOp function) { + for (auto op : function.getBody().getOps()) { + return op; + } + return nullptr; +} + // Returns the first user of the given operation, optionally of the given // type if provided. If there is no user or user of type, return nullptr. template @@ -190,6 +200,18 @@ Operation* FindUserOfType(Operation* op) { return nullptr; } +// Returns the first user of the given operation, optionally of the given +// type if provided. If there is no user or user of type, return nullptr. +template +Operation* FindOperandOfType(Operation* op) { + for (Value operand_value : op->getOperands()) { + if (isa(operand_value.getDefiningOp())) { + return operand_value.getDefiningOp(); + } + } + return nullptr; +} + // Returns the function attribute for the given call op which is lifted for // quantization. inline FlatSymbolRefAttr GetFuncAttr(TF::PartitionedCallOp call_op) { @@ -216,6 +238,11 @@ inline bool HasQuantizableTrait(Operation* op) { // is quantized. bool IsHybridQuantizedOp(Operation* op); +// Returns whether a given `stablehlo.dot_general` can be legalizable to +// `tfl.fully_connected`. +absl::StatusOr IsDotGeneralFullyConnected( + ::mlir::stablehlo::DotGeneralOp dot_general_op); + // Returns the quantization dimension for a given `stablehlo.dot_general` op, // or `std::nullopt` if the given op is not per-channel quantizable. 
std::optional GetDotGeneralQuantizationDim( diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc index f6e633aa4c7861..ca0df77f81b51c 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -33,11 +34,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/status_matchers.h" namespace mlir::quant { namespace { using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::ConstantOp; using ::mlir::stablehlo::ConvolutionOp; using ::mlir::stablehlo::DotGeneralOp; using ::mlir::stablehlo::SubtractOp; @@ -47,6 +50,7 @@ using ::testing::IsEmpty; using ::testing::IsNull; using ::testing::NotNull; using ::testing::Optional; +using ::tsl::testing::StatusIs; using AttrsAndConstraintsTest = ::mlir::quant::QuantizationTestBase; @@ -70,10 +74,11 @@ constexpr absl::string_view kModuleDynamic = R"mlir( constexpr absl::string_view kModuleMultipleUses = R"mlir( module { - func.func @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + func.func @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %cst = stablehlo.constant dense<1.0> : tensor<1x3xf32> %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> 
tensor<1x3xf32> - %1 = stablehlo.subtract %arg2, %0 : tensor<1x3xf32> - %2 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + %1 = stablehlo.subtract %cst, %0 : tensor<1x3xf32> + %2 = stablehlo.add %0, %cst : tensor<1x3xf32> return %2 : tensor<1x3xf32> } } @@ -326,6 +331,22 @@ TEST_F(AttrsAndConstraintsTest, FindUserOfDifferentTypes) { EXPECT_THAT(FindUserOfType(dot_general_op), IsNull()); } +TEST_F(AttrsAndConstraintsTest, FindOperandOfDifferentTypes) { + OwningOpRef module_op = ParseModuleOpString(kModuleMultipleUses); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto subtract_op = FindOperationOfType(main_fn); + ASSERT_THAT(subtract_op, NotNull()); + + EXPECT_THAT(FindOperandOfType(subtract_op), NotNull()); + EXPECT_THAT(FindOperandOfType(subtract_op), NotNull()); + EXPECT_THAT(FindOperandOfType<>(subtract_op), NotNull()); + EXPECT_THAT(FindOperandOfType(subtract_op), IsNull()); +} + TEST_F(AttrsAndConstraintsTest, XlaCallModuleOpGetFuncAttr) { OwningOpRef module_op = ParseModuleOpString(kModuleXlaCallModule); ASSERT_TRUE(module_op); @@ -450,6 +471,37 @@ constexpr absl::string_view kModuleDotGeneralBatchMatmul = R"mlir( } )mlir"; +TEST_F(AttrsAndConstraintsTest, IsDotGeneralFullyConnectedReturnsError) { + DotGeneralOp dot_general_op = nullptr; + StatusIs(absl::StatusCode::kInvalidArgument, + "Given dot_general op cannot be null when checking " + "`IsDotGeneralBatchMatmul`"); +} + +TEST_F(AttrsAndConstraintsTest, IsDotGeneralFullyConnectedReturnsTrue) { + OwningOpRef module_op = + ParseModuleOpString(kModuleDotGeneralFullyConnected); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto dot_general_op = *main_fn.getOps().begin(); + EXPECT_THAT(IsDotGeneralFullyConnected(dot_general_op), true); +} + +TEST_F(AttrsAndConstraintsTest, IsDotGeneralFullyConnectedReturnsFalse) { + OwningOpRef module_op = + 
ParseModuleOpString(kModuleDotGeneralBatchMatmul); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto dot_general_op = *main_fn.getOps().begin(); + EXPECT_THAT(IsDotGeneralFullyConnected(dot_general_op), false); +} + TEST_F(AttrsAndConstraintsTest, DotGeneralFullyConnectedReturnsQuantDim) { OwningOpRef module_op = ParseModuleOpString(kModuleDotGeneralFullyConnected); diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h index bd7421d376102b..bfef9a13df1a01 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h @@ -34,6 +34,10 @@ inline constexpr StringRef kFusedFunctionAttr = "tf_quant.composite_function"; // The keyword to detect if this is a `NullAttribute`. inline constexpr StringRef kNullAttributeValue = "N/A"; +// Prefixes attached to lifted functions. +constexpr StringRef kQuantizedFuncPrefix = "quantized_"; +constexpr StringRef kCompositeFuncPrefix = "composite_"; + // The attribute will be used for TF::XlaCallModuleOp to restore the original // function name when loading it back. inline constexpr StringRef kOriginalStablehloEntryFunctionAttrName = diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc index c37a997217d2b7..5e5e103ba72018 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc index 327d109946e031..216a4a2b3d58e9 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc @@ -87,6 +87,21 @@ void InitializeStateForValue( cached->second = next_state_index; } +bool HasPerAxisQuantizedOperand(Operation* op) { + for (int i = 0; i < op->getNumOperands(); ++i) { + if (auto dq_op = dyn_cast_or_null( + op->getOperand(i).getDefiningOp())) { + auto type = dq_op.getArg().getType().cast().getElementType(); + if (auto per_axis_qtype = + QuantizedType::getQuantizedElementType(type) + .dyn_cast_or_null()) { + return true; + } + } + } + return false; +} + } // namespace void QuantizationDriver::InitializeArgState(const BlockArgument arg, @@ -480,7 +495,10 @@ void QuantizationDriver::PreprocessConstantOps() { // Skip if the value is NaN or INF. // Otherwise the illegal scale/zp will be calculated. 
auto float_attr = cst.getValueAttr().dyn_cast(); - if (float_attr && !float_attr.getValues()[0].isFinite()) return; + if (float_attr && (float_attr.getValues().empty() || + !float_attr.getValues()[0].isFinite())) { + return; + } const Value value = cst.getResult(); builder_.setInsertionPoint(cst); @@ -788,11 +806,18 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { // quantization for the quantized kernel. If the quantized dimension // changes, the following logic no longer works as the same `params` // shouldn't be used for both input and output quantization params. - // E.g. TransposeOp's propagation is handled in - // `PropagateTransposedQuantDim` in PrepareQuantize. + // E.g. During TransposeOp's quantization propagation in + // PrepareQuantize, if the quantization is per-axis and the + // QuantizedDimension is transposed, then the output q-dq params must + // reflect the new QuantizedDimension. So, check and skip the + // propagation if any of the operands has a per-axis quantized type param + // and `RequiredSameQuantizedAxes` set to false. + // Currently, these lines of code are only applicable to TFL_TransposeOp + // and the output q-dq propagation for this Op is performed in + // `PropagateTransposedPerAxisQuantDim`. if (is_qdq_conversion_ && !scale_spec->required_same_quantized_axes_func()) { - continue; + if (HasPerAxisQuantizedOperand(op)) continue; } // Use the final state to set all the operands' parameters. diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h index e1d36df58a3fd9..453dc419371932 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h @@ -134,11 +134,14 @@ using OpQuantSpecGetter = // Quantization scale spec of an op. 
The information defined in the MLIR // interfaces FixedOutputRangeInterface and SameOperandsAndResultsScale should // be checked first if present. +// TODO: b/323478683: Consider deprecating this. struct OpQuantScaleSpec { // Whether this op has a fixed range requirement (e.g. sigmoid) bool has_fixed_output_range = false; - // Whether this op should have same result and operand scales (e.g. concat) + // Whether this op should have same operand and result scales (e.g. concat) bool has_same_scale_requirement = false; + // Whether this op should have same operand and result type (e.g. gather) + bool has_same_operand_and_result_type_requirement = false; // Returns the fixed output range, when has_fixed_output_range is set. GetFixedOutputRangeFunc fixed_output_range_func; // Returns whether same operands and results scales are required. diff --git a/tensorflow/compiler/mlir/quantization/common/test_base.h b/tensorflow/compiler/mlir/quantization/common/test_base.h index a1a770ff616dee..4564de6c3d5603 100644 --- a/tensorflow/compiler/mlir/quantization/common/test_base.h +++ b/tensorflow/compiler/mlir/quantization/common/test_base.h @@ -62,15 +62,6 @@ class QuantizationTestBase : public Test { return parseSourceString(module_op_str, ctx_.get()); } - // Returns the first operation with the given type in the function. - template - OpType FindOperationOfType(func::FuncOp function) { - for (auto op : function.getBody().getOps()) { - return op; - } - return nullptr; - } - // Convenience function that returns the first operation of type `OpT` from // the `@main` function in `module_op`. Useful when testing with a text // representation of a `ModuleOp` containing a single function `@main`. 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 3b53b3c74bb7cb..3da423119752cb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -8,6 +8,7 @@ load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") package_group( name = "internal_visibility_allowlist_package", packages = [ + "//learning/brain/mlir/quantization/stablehlo/python/integration_test/...", "//tensorflow/compiler/mlir/lite/...", "//tensorflow/compiler/mlir/quantization/...", "//tensorflow/compiler/mlir/tf2xla/transforms/...", @@ -54,6 +55,7 @@ cc_library( "passes/lift_quantizable_spots_as_functions.cc", "passes/lift_quantizable_spots_as_functions_fusion.inc", "passes/lift_quantizable_spots_as_functions_simple.inc", + "passes/merge_fusion_with_dequantize.cc", "passes/nchw_convolution_to_nhwc.cc", "passes/optimize_graph.cc", "passes/post_quantize.cc", @@ -67,6 +69,7 @@ cc_library( "passes/restore_function_name.cc", "passes/unfuse_mhlo_batch_norm.cc", "passes/unwrap_xla_call_module_op.cc", + "passes/xla_call_module_to_call.cc", ], hdrs = [ "passes/passes.h", @@ -95,6 +98,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:permutation", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:report", "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -111,12 +115,11 @@ cc_library( "//tensorflow/core/ir/types:Dialect", "//tensorflow/core/platform:path", "//tensorflow/core/tpu:tpu_defs", - "//tensorflow/lite/kernels:padding", - "//tensorflow/lite/kernels/internal:quantization_util", 
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/random", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -169,10 +172,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -412,7 +412,7 @@ tf_cc_test( "@local_xla//xla/mlir_hlo", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt:pjrt_executable", - "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", + "@local_xla//xla/pjrt/cpu:cpu_client", "@local_xla//xla/tests:literal_test_util", "@stablehlo//:chlo_ops", ], @@ -528,7 +528,6 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -745,7 +744,9 @@ tf_proto_library( # py_proto_library( # name = "quantization_config_py_pb2", # api_version = 2, -# visibility = [":internal_visibility_allowlist_package"], +# visibility = [ +# ":internal_visibility_allowlist_package", +# ], # deps = [":quantization_config_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index 77629c7719bf44..5ae92d648bf5c9 100644 --- 
a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -321,7 +321,15 @@ cc_library( hdrs = ["report.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", ], ) @@ -330,8 +338,12 @@ tf_cc_test( srcs = ["report_test.cc"], deps = [ ":report", + "//tensorflow/compiler/mlir/quantization/common:test_base", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD index 5783ffddd4f050..3fbd4ed586e45f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD @@ -25,6 +25,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:min_max_value", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton", @@ -99,3 +100,20 @@ tf_cc_test( "@local_tsl//tsl/platform:status_matchers", ], ) + +cc_library( + name = 
"calibration_parameters", + srcs = [], + hdrs = ["calibration_parameters.h"], + compatible_with = get_compatible_with_portable(), + deps = ["//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc"], +) + +tf_cc_test( + name = "calibration_parameters_test", + srcs = ["calibration_parameters_test.cc"], + deps = [ + ":calibration_parameters", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h new file mode 100644 index 00000000000000..ffad37d15d243c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h @@ -0,0 +1,79 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_ + +#include +#include +#include + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace stablehlo::quantization { + +// TODO: b/321158562 - Make the number of bins configurable. +// Default number of histogram bins for each batch sample. 
+constexpr int32_t kDefaultNumOfBins = 1 << 9; + +// Calculates the bin width from the range and expected number of bins. The +// bin width is formalized to the form of 2^n. As a consequence, the actual +// number of bins might be smaller than the given `num_bins`. +inline float CalculateBinWidth(const float min_value, const float max_value, + const int32_t num_bins) { + const float raw_bin_width = (max_value - min_value) / num_bins; + return std::pow(2, std::ceil(std::log2(raw_bin_width))); +} + +// Calculates the lower bound of the histogram. The lower bound is in form of +// `N * bin_width`. +inline float CalculateLowerBound(const float min_value, const float bin_width) { + return std::floor(min_value / bin_width) * bin_width; +} + +// Calculates the bin index of the current value. +inline int32_t CalculateBinIndex(const float value, const float lower_bound, + const float bin_width) { + return std::floor((value - lower_bound) / bin_width); +} + +// Same as `CalculateBinIndex` but clamps to avoid out-of-bound. +inline int32_t CalculateBinIndexSafe(const float value, const float lower_bound, + const float bin_width, + const int32_t num_bins) { + const int32_t bin_index = CalculateBinIndex(value, lower_bound, bin_width); + return std::clamp(bin_index, 0, num_bins - 1); +} + +// Checks if the given method is a histogram-based calibration method. +inline bool IsHistogramCalibration( + const CalibrationOptions::CalibrationMethod method) { + return method == + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE || + method == + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE || + method == CalibrationOptions:: + CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY || + method == + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC; +} + +// Gets the number of bins for the given calibration method. +inline int32_t GetNumBins(const CalibrationOptions::CalibrationMethod method) { + return IsHistogramCalibration(method) ? 
kDefaultNumOfBins : 0; +} + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters_test.cc new file mode 100644 index 00000000000000..bff3f5092a8644 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h" + +#include +#include + +#include + +namespace stablehlo::quantization { +namespace { + +// Calculates the number of bins from the range and bin width. 
+inline int32_t CalculateActualNumBins(const float min_value, + const float max_value, + const float bin_width) { + const float lower_bound = CalculateLowerBound(min_value, bin_width); + return std::ceil((max_value - lower_bound) / bin_width); +} + +TEST(CalibrationParametersTest, CalculateBinWidthSmallerThanOne) { + float bin_width = CalculateBinWidth(/*min_value=*/0.0, /*max_value=*/25.0, + /*num_bins=*/256); + EXPECT_FLOAT_EQ(bin_width, 0.125); + int32_t actual_num_bins = + CalculateActualNumBins(/*min_value=*/0.0, /*max_value=*/25.0, bin_width); + EXPECT_EQ(actual_num_bins, 200); + + // Calculate the bin width with the actual num bins. + float raw_bin_width = 25.0 / actual_num_bins; + EXPECT_FLOAT_EQ(bin_width, raw_bin_width); +} + +TEST(CalibrationParametersTest, CalculateBinWidthLargerThanOne) { + float bin_width = CalculateBinWidth(/*min_value=*/0.0, /*max_value=*/360.0, + /*num_bins=*/256); + EXPECT_FLOAT_EQ(bin_width, 2.0); + int32_t actual_num_bins = + CalculateActualNumBins(/*min_value=*/0.0, /*max_value=*/360.0, bin_width); + EXPECT_EQ(actual_num_bins, 180); + + // Calculate the bin width with the actual num bins. + float raw_bin_width = 360.0 / actual_num_bins; + EXPECT_FLOAT_EQ(bin_width, raw_bin_width); +} + +TEST(CalibrationParametersTest, CalculateBinWidthDivisible) { + float bin_width = CalculateBinWidth(/*min_value=*/0.0, /*max_value=*/256.0, + /*num_bins=*/256); + EXPECT_FLOAT_EQ(bin_width, 1.0); + int32_t actual_num_bins = + CalculateActualNumBins(/*min_value=*/0.0, /*max_value=*/256.0, bin_width); + EXPECT_EQ(actual_num_bins, 256); + + // Calculate the bin width with the actual num bins. + float raw_bin_width = 256.0 / actual_num_bins; + EXPECT_FLOAT_EQ(bin_width, raw_bin_width); +} + +TEST(CalibrationParametersTest, CalculateNumBinsDivisible) { + int32_t num_bins = CalculateActualNumBins( + /*min_value=*/0.0, /*max_value=*/4.0, /*bin_width=*/2.0); + + // Expect 2 bins: [0, 2), [2, 4]. 
+ EXPECT_EQ(num_bins, 2); +} + +TEST(CalibrationParametersTest, CalculateNumBinsNotDivisible) { + int32_t num_bins = CalculateActualNumBins( + /*min_value=*/0.0, /*max_value=*/5.0, /*bin_width=*/2.0); + + // Expect 3 bins: [0, 2), [2, 4), [4, 6]. + EXPECT_EQ(num_bins, 3); +} + +TEST(CalibrationParametersTest, CalculateBinIndex) { + int32_t bin_index = CalculateBinIndexSafe(/*value=*/3.0, /*lower_bound=*/0.0, + /*bin_width=*/2.0, /*num_bins=*/2); + EXPECT_EQ(bin_index, 1); +} + +TEST(CalibrationParametersTest, CalculateBinIndexMaxValue) { + int32_t bin_index = CalculateBinIndexSafe(/*value=*/4.0, /*lower_bound=*/0.0, + /*bin_width=*/2.0, /*num_bins=*/2); + EXPECT_EQ(bin_index, 1); +} + +} // namespace +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc index ba1671ceb696ca..ce626145318b9f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h" +#include #include #include #include @@ -122,7 +123,7 @@ absl::StatusOr CalibrationComponent::Run( // Translates `RepresentativeDatasetConfig`s to signature key -> // `RepresentativeDatasetFile` mapping. const auto dataset_configs = - config.static_range_ptq_preset().representative_datasets(); + config.calibration_options().representative_datasets(); const std::vector dataset_config_vector( dataset_configs.begin(), dataset_configs.end()); TF_ASSIGN_OR_RETURN( @@ -132,10 +133,13 @@ absl::StatusOr CalibrationComponent::Run( // Runs calibration on the exported model. 
The statistics will be stored in a // separate singleton object `CalibratorSingleton` and are directly added to // `exported_model` without re-importing it. - py_function_lib_->RunCalibration( - precalibrated_saved_model_dir, signature_keys_, tags_, - config.calibration_options(), - /*force_graph_mode_calibration=*/true, representative_dataset_file_map); + if (py_function_lib_->RunCalibration( + precalibrated_saved_model_dir, signature_keys_, tags_, + /*force_graph_mode_calibration=*/true, + representative_dataset_file_map) == std::nullopt) { + return absl::InternalError( + "CalibrationComponent error: Failed to run calibration."); + } if (absl::Status status = AddCalibrationStatistics( module_op, config.calibration_options(), *py_function_lib_); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc index 39f4ca8449ae05..19a44097458f1a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc @@ -15,13 +15,13 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" #include -#include #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" @@ -55,11 +55,18 @@ absl::Status AddCalibrationStatistics( return; } - const auto [min_value, max_value] = + const std::optional min_max_values = py_function_library.GetCalibrationMinMaxValue(*statistics, calibration_options); CalibratorSingleton::ClearData(id); + if (min_max_values == std::nullopt) { + status = absl::InternalError( + "Cannot find min/max values for calibration statistics."); + return; + } + + const auto [min_value, max_value] = *min_max_values; mlir::OpBuilder builder(aggregator_op); aggregator_op->setAttr("min", builder.getF32FloatAttr(min_value)); aggregator_op->setAttr("max", builder.getF32FloatAttr(max_value)); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc index 0f9932d053cb4d..b3aa1500a0a3c7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc @@ -102,12 +102,21 @@ QuantizationSpec GetDefaultStaticRangePtqSpec(StaticRangePtqPreset preset) { QuantizationSpec spec{}; // Default for all ops. spec.mutable_matcher()->mutable_function_name()->set_regex( - preset.enable_full_int_quantization() ? ".*" : "^.*(conv|dot|gather).*"); + preset.enable_full_int_quantization() ? 
".*" + : "^.*(dot_general|gather).*"); spec.mutable_method()->mutable_static_range_ptq(); return spec; } +QuantizationSpec GetDefaultWeightOnlyPtqSpec(WeightOnlyPtqPreset preset) { + QuantizationSpec spec{}; + spec.mutable_matcher()->mutable_function_name()->set_regex( + "^.*(conv|dot_general).*"); + spec.mutable_method()->mutable_weight_only_ptq(); + return spec; +} + // Returns a `QuantizationSpec` for performing static-range PTQ on the // convolution quantizable unit family. Enables per-channel quantization for // weights, on the channel dimension. @@ -122,14 +131,12 @@ QuantizationSpec GetDefaultStaticRangePtqSpec(StaticRangePtqPreset preset) { // value {dimension_specs {dimension: 3}}}} // }} // } -QuantizationSpec GetStaticRangePtqSpecForConvolution() { +QuantizationSpec GetPtqSpecForConvolution(Method::MethodCase method_case) { QuantizationSpec spec{}; // Matches all convolution quantizable unit family. spec.mutable_matcher()->mutable_function_name()->set_regex( "composite_conv.*"); - StaticRangePtq& static_range_ptq_spec = - *spec.mutable_method()->mutable_static_range_ptq(); // Enable per-channel quantization for convolution weights. QuantizedType conv_weight_quantized_type{}; @@ -140,8 +147,17 @@ QuantizationSpec GetStaticRangePtqSpecForConvolution() { // The index of weight operands passed to lifted functions for convolution // is 1. 
- static_range_ptq_spec.mutable_input_quantized_types()->try_emplace( - 1, std::move(conv_weight_quantized_type)); + if (method_case == Method::kStaticRangePtq) { + StaticRangePtq& static_range_ptq_spec = + *spec.mutable_method()->mutable_static_range_ptq(); + static_range_ptq_spec.mutable_input_quantized_types()->try_emplace( + 1, std::move(conv_weight_quantized_type)); + } else if (method_case == Method::kWeightOnlyPtq) { + WeightOnlyPtq& weight_only_ptq_spec = + *spec.mutable_method()->mutable_weight_only_ptq(); + weight_only_ptq_spec.mutable_input_quantized_types()->try_emplace( + 1, std::move(conv_weight_quantized_type)); + } return spec; }; @@ -164,13 +180,34 @@ void ExpandStaticRangePtqPreset(const StaticRangePtqPreset& preset, QuantizationSpecs new_specs{}; *new_specs.add_specs() = GetDefaultStaticRangePtqSpec(/*preset=*/config.static_range_ptq_preset()); - *new_specs.add_specs() = GetStaticRangePtqSpecForConvolution(); + *new_specs.add_specs() = + GetPtqSpecForConvolution(Method::MethodCase::kStaticRangePtq); + + // Append user-provided specs to override existing specs. + const QuantizationSpecs& previous_specs = config.specs(); + new_specs.mutable_specs()->Add(previous_specs.specs().begin(), + previous_specs.specs().end()); + + config.clear_static_range_ptq_preset(); + config.mutable_specs()->Swap(&new_specs); +} + +void ExpandWeightOnlyPtqPreset(const WeightOnlyPtqPreset& preset, + QuantizationConfig& config) { + // Create a new `QuantizationSpecs` to replace the existing one. The + // expansion from `WeightOnlyPtqPreset` gets populated first and then + // user-provided explicit `QuantizationSpec`s will be appended. + QuantizationSpecs new_specs{}; + *new_specs.add_specs() = + GetDefaultWeightOnlyPtqSpec(/*preset=*/config.weight_only_ptq_preset()); + // TODO: b/307625297 - Add per-channel weight only support. // Append user-provided specs to override existing specs. 
const QuantizationSpecs& previous_specs = config.specs(); new_specs.mutable_specs()->Add(previous_specs.specs().begin(), previous_specs.specs().end()); + config.clear_weight_only_ptq_preset(); config.mutable_specs()->Swap(&new_specs); } @@ -184,6 +221,9 @@ QuantizationConfig ExpandPresets(const QuantizationConfig& config) { case QuantizationConfig::kStaticRangePtqPreset: ExpandStaticRangePtqPreset(config.static_range_ptq_preset(), new_config); break; + case QuantizationConfig::kWeightOnlyPtqPreset: + ExpandWeightOnlyPtqPreset(config.weight_only_ptq_preset(), new_config); + break; default: // Preset has not been specified. The expansion is a no-op. break; @@ -192,6 +232,16 @@ QuantizationConfig ExpandPresets(const QuantizationConfig& config) { return new_config; } +bool HasQuantizationMethod(const QuantizationSpecs& specs, + Method::MethodCase method_case) { + for (const auto& spec : specs.specs()) { + if (spec.method().method_case() == method_case) { + return true; + } + } + return false; +} + QuantizationConfig PopulateDefaults( const QuantizationConfig& user_provided_config) { QuantizationConfig config = user_provided_config; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h index 5dc4554d784c92..19f250bedfe1b8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h @@ -41,6 +41,10 @@ QuantizationConfig PopulateDefaults( // - No-op. QuantizationConfig ExpandPresets(const QuantizationConfig& config); +// Returns whether a given QuantizationSpecs has the given quantization method. 
+bool HasQuantizationMethod(const QuantizationSpecs& specs, + Method::MethodCase method_case); + } // namespace stablehlo::quantization #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc index e3f2bfde3d10c3..c46daaf1252f26 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc @@ -198,7 +198,7 @@ TEST(ExpandPresetsTest, ExpandStaticRangePtqPresetDefault) { const QuantizationSpec& spec = new_config.specs().specs(0); EXPECT_THAT(spec.matcher().function_name().regex(), - StrEq("^.*(conv|dot|gather).*")); + StrEq("^.*(dot_general|gather).*")); EXPECT_TRUE(spec.method().has_static_range_ptq()); } @@ -274,5 +274,18 @@ TEST(ExpandPresetsTest, ExpandStaticRangePtqPresetThenAppendExplicitSpecs) { EXPECT_TRUE(third_spec.method().has_no_quantization()); } +TEST(ExpandPresetsTest, ExpandWeightOnlyPtqPresetDefault) { + QuantizationConfig config{}; + *config.mutable_weight_only_ptq_preset() = WeightOnlyPtqPreset(); + + const QuantizationConfig new_config = ExpandPresets(config); + ASSERT_THAT(new_config.specs().specs(), SizeIs(1)); + + const QuantizationSpec& spec = new_config.specs().specs(0); + EXPECT_THAT(spec.matcher().function_name().regex(), + StrEq("^.*(conv|dot_general).*")); + EXPECT_TRUE(spec.method().has_weight_only_ptq()); +} + } // namespace } // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index ebe950c58142f6..622ff502c01ed9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -54,18 +54,19 @@ void AddPreCalibrationPasses(OpPassManager& pm, 
pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass()); } -void AddPostCalibrationPasses( - OpPassManager& pm, const PipelineConfig& pipeline_config, - const StaticRangePtqPreset& static_range_ptq_preset) { +void AddPostCalibrationPasses(OpPassManager& pm, + const PipelineConfig& pipeline_config, + const QuantizationSpecs& specs) { QuantizeCompositeFunctionsPassOptions options; - // TODO: b/331120943 - Use QuantizationConfig instead of preset flags. - options.enable_per_channel_quantized_weight_ = - static_range_ptq_preset.enable_per_channel_quantized_weight(); - options.enable_full_int_quantization_ = - static_range_ptq_preset.enable_full_int_quantization(); + // TODO: b/331120943 - Temporarily set below to true, signaling per-channel + // quantization will be applied for all where applicable. This will be + // replaced by individual `Method` in `QuantizationSpecs`. + options.enable_per_channel_quantized_weight_ = true; // For debugging purposes. options.mlir_dump_file_name_ = "quantize_composite_functions"; options.enable_weight_only_ = false; + options.merge_fusion_with_dequantize_ = + pipeline_config.merge_fusion_with_dequantize(); AddShapeLegalizationPasses(pm); pm.addNestedPass( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h index 4f94506b6c184e..408152f6fc5a49 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h @@ -34,8 +34,7 @@ void AddPreCalibrationPasses( void AddPostCalibrationPasses( OpPassManager& pm, const ::stablehlo::quantization::PipelineConfig& pipeline_config, - const ::stablehlo::quantization::StaticRangePtqPreset& - static_range_ptq_preset); + const ::stablehlo::quantization::QuantizationSpecs& specs); // Adds passes for weight-only quantization. 
void AddWeightOnlyQuantizationPasses( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc index 6f5f10b48f41f5..001ece707cfe90 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc @@ -29,7 +29,7 @@ namespace mlir::quant::stablehlo { using ::stablehlo::quantization::PipelineConfig; using ::stablehlo::quantization::QuantizationConfig; -using ::stablehlo::quantization::StaticRangePtqPreset; +using ::stablehlo::quantization::QuantizationSpecs; using ::tensorflow::quantization::RunPasses; PostCalibrationComponent::PostCalibrationComponent( @@ -40,18 +40,17 @@ absl::StatusOr PostCalibrationComponent::Run( ModuleOp module_op, const QuantizationConfig& config) { TF_RETURN_IF_ERROR(RunPasses( kName, /*add_passes_func=*/ - [&config, this](PassManager& pm) { - AddPostCalibrationPasses(pm, config.pipeline_config(), - config.static_range_ptq_preset()); + [&config](PassManager& pm) { + AddPostCalibrationPasses(pm, config.pipeline_config(), config.specs()); }, *ctx_, module_op)); return module_op; } void PostCalibrationComponent::AddPasses( - OpPassManager& pm, const StaticRangePtqPreset& static_range_ptq_preset, + OpPassManager& pm, const QuantizationSpecs& specs, const PipelineConfig& pipeline_config) const { - AddPostCalibrationPasses(pm, pipeline_config, static_range_ptq_preset); + AddPostCalibrationPasses(pm, pipeline_config, specs); } } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h index 3c218c9f857524..6e3762817e16a1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h @@ -47,8 +47,7 @@ class PostCalibrationComponent 
: public Component { void AddPasses( OpPassManager& pm, - const ::stablehlo::quantization::StaticRangePtqPreset& - static_range_ptq_preset, + const ::stablehlo::quantization::QuantizationSpecs& specs, const ::stablehlo::quantization::PipelineConfig& pipeline_config) const; private: diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc index a423bdc5f80142..6143b21eec32cd 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc @@ -38,7 +38,7 @@ absl::StatusOr PreCalibrationComponent::Run( ModuleOp module_op, const QuantizationConfig& config) { TF_RETURN_IF_ERROR(RunPasses( kName, /*add_passes_func=*/ - [&config, this](PassManager& pm) { + [&config](PassManager& pm) { AddPreCalibrationPasses(pm, config.calibration_options(), config.specs(), config.debugger_config()); }, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc index ef24c16dbf4acc..93be3516d76f8d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc @@ -14,16 +14,142 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h" +#include +#include #include +#include "absl/strings/str_cat.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep namespace mlir::quant::stablehlo { +namespace { using ::stablehlo::quantization::QuantizationResult; +using ::stablehlo::quantization::QuantizationResults; +using ::tsl::protobuf::TextFormat; + +// Given a `quantized_func_name` that starts with `kQuantizedFuncPrefix`, +// converts `kQuantizedFuncPrefix` to `kCompositeFuncPrefix`. +std::string GetCompositeFunctionName(const StringRef quantized_func_name) { + return Twine(kCompositeFuncPrefix) + .concat(quantized_func_name.rsplit(kQuantizedFuncPrefix).second) + .str(); +} + +// Retrieves `QuantizationResult` from `call_op`. If the callee's name starts +// with `kQuantizedFuncPrefix` then a `QuantizationResult` will be returned with +// its `name` field set to the callee's name reverted back to the lifted +// function's name. Otherwise, returns `std::nullopt`. +std::optional GetQuantizationResult(func::CallOp call_op) { + const StringRef callee_name = call_op.getCalleeAttr().getValue(); + + if (callee_name.starts_with(kQuantizedFuncPrefix)) { + // TODO: b/329554870 - Transfer the `Method` used to quantize the op. 
+ QuantizationResult result{}; + result.mutable_quantizable_unit()->set_name( + GetCompositeFunctionName(callee_name)); + return result; + } else { + return std::nullopt; + } +} + +// Retrieves `QuantizationResult` from `xla_call_module_op`. If +// `xla_call_module_op` is a quantizable unit, then a `QuantizationResult` will +// be returned with its `name` field set to the callee's name. The `method` +// field will be set to `NoQuantization` because remaining `xla_call_module_op`s +// means they are not quantized. Returns `std::nullopt` if `xla_call_module_op` +// is not a quantizable unit. +std::optional GetQuantizationResult( + TF::XlaCallModuleOp xla_call_module_op) { + const StringAttr callee_name_attr = + xla_call_module_op + ->getDiscardableAttr(kOriginalStablehloEntryFunctionAttrName) + .dyn_cast_or_null(); + + // `TF::XlaCallModuleOp` without the `_original_entry_function` means it is + // not a quantizable unit. + if (callee_name_attr == nullptr) return std::nullopt; + + if (callee_name_attr.getValue().starts_with(kCompositeFuncPrefix)) { + QuantizationResult result{}; + result.mutable_quantizable_unit()->set_name( + callee_name_attr.getValue().str()); + result.mutable_method()->mutable_no_quantization(); + return result; + } else { + return std::nullopt; + } +} + +// Populates quantized ops from `module_op` to `results`. After going through +// the quantization passes, quantized ops are represented as `func::CallOp` with +// a callee's prefix of `quantized_`. +void PopulateQuantizedResults(ModuleOp module_op, + QuantizationResults& results) { + module_op.walk([&results](func::CallOp call_op) { + std::optional result = GetQuantizationResult(call_op); + if (result == std::nullopt) return WalkResult::skip(); + + *results.add_results() = std::move(*result); + return WalkResult::advance(); + }); +} + +// Populates non-quantized ops from `module_op` to `results`. 
After going +// through the quantization passes, non-quantized quantizable units remain as +// `TF::XlaCallModuleOp` with a callee's prefix of `composite_`. +void PopulateNonQuantizedResults(ModuleOp module_op, + QuantizationResults& results) { + module_op.walk([&results](TF::XlaCallModuleOp xla_call_module_op) { + std::optional result = + GetQuantizationResult(xla_call_module_op); + if (result == std::nullopt) return WalkResult::skip(); + + *results.add_results() = std::move(*result); + return WalkResult::advance(); + }); +} + +} // namespace + +QuantizationReport::QuantizationReport(ModuleOp module_op) + : quantization_results_(CollectResultsFromModuleOp(module_op)) {} + +QuantizationResults QuantizationReport::CollectResultsFromModuleOp( + ModuleOp module_op) const { + QuantizationResults results{}; + + PopulateQuantizedResults(module_op, results); + PopulateNonQuantizedResults(module_op, results); + + return results; +} void QuantizationReport::AddQuantizationResult(QuantizationResult&& result) { *quantization_results_.add_results() = std::move(result); } +std::string QuantizationReport::ToString() const { + std::string results_str{}; + TextFormat::PrintToString(quantization_results_, &results_str); + + return absl::StrCat("===== Quantization Report =====\n\n", results_str, + "\n===== Quantization Report End =====\n\n"); +} + +void QuantizationReport::Print() const { + llvm::outs() << ToString(); + llvm::outs().flush(); // Show the report immediately. +} + } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h index 94eb47463f16c1..a362bb758cb60c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h @@ -15,6 +15,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" namespace mlir::quant::stablehlo { @@ -27,6 +30,10 @@ class QuantizationReport { public: QuantizationReport() = default; + // Initializes `QuantizationReport` by collecting `QuantizationResults` from + // `module_op`. + explicit QuantizationReport(ModuleOp module_op); + // Adds a `QuantizationResult` to the report. void AddQuantizationResult( ::stablehlo::quantization::QuantizationResult&& result); @@ -37,7 +44,16 @@ class QuantizationReport { return quantization_results_; } + // Returns a human-readable string representation of this report. + std::string ToString() const; + + // Prints a human-readable report to stdout. + void Print() const; + private: + ::stablehlo::quantization::QuantizationResults CollectResultsFromModuleOp( + ModuleOp module_op) const; + // Quantization results that are registered in this report. A quantization // result may be added manually by calling `AddQuantizationResult`. ::stablehlo::quantization::QuantizationResults quantization_results_; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc index f6897f7fde401d..4783fb6beebc2d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc @@ -14,11 +14,17 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h" +#include #include #include #include +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep namespace mlir::quant::stablehlo { namespace { @@ -30,15 +36,18 @@ using ::stablehlo::quantization::QuantizationResults; using ::testing::IsEmpty; using ::testing::SizeIs; using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; -TEST(QuantizationReportTest, GetQuantizationResultsReturnsEmptyResults) { +using QuantizationReportTest = ::mlir::quant::QuantizationTestBase; + +TEST_F(QuantizationReportTest, GetQuantizationResultsReturnsEmptyResults) { QuantizationReport report{}; const QuantizationResults& results = report.GetQuantizationResults(); ASSERT_THAT(results.results(), IsEmpty()); } -TEST(QuantizationReportTest, AddQuantizationResult) { +TEST_F(QuantizationReportTest, AddQuantizationResult) { // Construct a `QuantizationResult` to add, representing a unit named // `quantized_my_function` that is not quantized. 
QuantizationResult result{}; @@ -60,5 +69,144 @@ TEST(QuantizationReportTest, AddQuantizationResult) { EXPECT_TRUE(first_result.method().has_no_quantization()); } +TEST_F(QuantizationReportTest, InitializeWithModuleOp) { + constexpr absl::string_view kQuantizedDotGeneral = R"mlir( + func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %2 = call @quantized_dot_general_fn(%1, %0) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + func.func private @quantized_dot_general_fn(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kQuantizedDotGeneral); + ASSERT_TRUE(module_op); + + const QuantizationReport report(*module_op); + const QuantizationResults& results = report.GetQuantizationResults(); + ASSERT_THAT(results.results(), SizeIs(1)); + + // Test that the quantized `QuantizableUnit` corresponding to + // `composite_dot_general_fn` is captured. + // TODO: Transfer the `Method` used to quantize the op. 
+ const QuantizationResult& result = results.results(0); + EXPECT_THAT(result.quantizable_unit().name(), + StrEq("composite_dot_general_fn")); + EXPECT_FALSE(result.has_method()); +} + +TEST_F(QuantizationReportTest, InitializeWithModuleOpWithNonQuantizedOp) { + constexpr absl::string_view kNonQuantizedDotGeneral = R"mlir( + func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant dense<3.000000e+0> : tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kNonQuantizedDotGeneral); + ASSERT_TRUE(module_op); + + const QuantizationReport report(*module_op); + const QuantizationResults& results = report.GetQuantizationResults(); + ASSERT_THAT(results.results(), SizeIs(1)); + + // Test that the unquantized `QuantizableUnit` corresponding to + // `composite_dot_general_fn` is captured. The `Method` contains + // `NoQuantization`. 
+ const QuantizationResult& result = results.results(0); + EXPECT_THAT(result.quantizable_unit().name(), + StrEq("composite_dot_general_fn")); + EXPECT_TRUE(result.method().has_no_quantization()); +} + +TEST_F(QuantizationReportTest, + InitializeWithModuleOpWithQuantizedAndNonQuantizedOps) { + constexpr absl::string_view kQuantizedDotGeneralAndNonQuantizedDotGeneral = + R"mlir( + func.func @main(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x3xf32> { + // Non-quantized dot_general. + %0 = stablehlo.constant dense<3.000000e+0> : tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + // Quantized dot_general. + %2 = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>> + %3 = stablehlo.uniform_quantize %arg1 : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %4 = call @quantized_dot_general_fn_2(%3, %2) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %5 = stablehlo.uniform_dequantize %4 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // Add is there to prevent from dot_generals from being DCEed. + %6 = stablehlo.add %1, %5 : tensor<1x3xf32> + return %6 : tensor<1x3xf32> + } + + // Callee of non-quantized op. 
+ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + + // Callee of quantized op. + func.func private @quantized_dot_general_fn_2(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {1.000000e+0,2.000000e+0,3.000000e+0}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kQuantizedDotGeneralAndNonQuantizedDotGeneral); + ASSERT_TRUE(module_op); + + const QuantizationReport report(*module_op); + const QuantizationResults& results = report.GetQuantizationResults(); + ASSERT_THAT(results.results(), SizeIs(2)); + + // Test that the quantized op is captured in `results`. + const QuantizationResult& quantized_result = results.results(0); + EXPECT_THAT(quantized_result.quantizable_unit().name(), + StrEq("composite_dot_general_fn_2")); + EXPECT_FALSE(quantized_result.has_method()); + + // Test that the non-quantized op is captured in `results`. 
+ const QuantizationResult& non_quantized_result = results.results(1); + EXPECT_THAT(non_quantized_result.quantizable_unit().name(), + StrEq("composite_dot_general_fn_1")); + EXPECT_TRUE(non_quantized_result.method().has_no_quantization()); +} + +TEST_F(QuantizationReportTest, ToString) { + QuantizationResult result{}; + QuantizableUnit& quantizable_unit = *result.mutable_quantizable_unit(); + quantizable_unit.set_name("quantized_my_function"); + + Method& method = *result.mutable_method(); + method.mutable_no_quantization(); + + QuantizationReport report{}; + report.AddQuantizationResult(std::move(result)); + + // Check that the report string is equivalent to the textproto representation + // of the `QuantizationResults`. + std::string result_str{}; + TextFormat::PrintToString(report.GetQuantizationResults(), &result_str); + + EXPECT_THAT(report.ToString(), testing::HasSubstr("Quantization Report")); + EXPECT_THAT(report.ToString(), testing::HasSubstr(result_str)); + EXPECT_THAT(report.ToString(), testing::HasSubstr("Quantization Report End")); +} + } // namespace } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD index 35584857f5761f..61da2af4d3fb58 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD @@ -40,6 +40,8 @@ tf_cc_test( deps = [ ":stablehlo_op_quant_spec", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:func", "//tensorflow/compiler/mlir/quantization/common:test_base", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/tensorflow", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc 
b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc index c78ee607993385..3018db7b2649e9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc @@ -131,7 +131,7 @@ std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { return spec; } -std::unique_ptr GetStableHloQuantScaleSpec(Operation* op) { +std::unique_ptr GetStableHloQuantConstraints(Operation* op) { auto scale_spec = std::make_unique(); if (llvm::isa GetStableHloQuantScaleSpec(Operation* op) { mlir::stablehlo::SliceOp, mlir::stablehlo::TransposeOp>(op)) { scale_spec->has_same_scale_requirement = true; } + if (llvm::isa(op)) { + scale_spec->has_same_operand_and_result_type_requirement = true; + } return scale_spec; } @@ -165,7 +169,7 @@ bool IsOpQuantizableStableHlo(Operation* op) { return false; } - if (GetStableHloQuantScaleSpec(op)->has_same_scale_requirement) { + if (GetStableHloQuantConstraints(op)->has_same_scale_requirement) { return true; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h index 6edeb9829b6b63..6c688e823c96ba 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h @@ -28,9 +28,9 @@ namespace mlir::quant::stablehlo { // Returns StableHLO quantization specs for an op. std::unique_ptr GetStableHloOpQuantSpec(Operation* op); -// Returns quantization scale specs (fixed output, same scale) for a StableHLO -// op. -std::unique_ptr GetStableHloQuantScaleSpec(Operation* op); +// Returns quantization constraints (ex: fixed output, same scale) given +// a StableHLO op. +std::unique_ptr GetStableHloQuantConstraints(Operation* op); // Checks if an op is quantizable in StableHLO quantizer. 
Argument op is not // necessarily a StableHLO op. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc index b3ba4818284498..572bf0e05729b0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc @@ -26,6 +26,8 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -34,7 +36,9 @@ limitations under the License. 
namespace mlir::quant::stablehlo { namespace { +using ::mlir::stablehlo::GatherOp; using ::testing::IsEmpty; +using ::testing::IsTrue; using ::testing::NotNull; using ::testing::Pair; using ::testing::UnorderedElementsAre; @@ -284,5 +288,42 @@ TEST_F(GetStableHloOpQuantSpecTest, UnorderedElementsAre(Pair(1, 3))); } +using GetStableHloQuantConstraintsTest = ::mlir::quant::QuantizationTestBase; + +TEST_F(GetStableHloQuantConstraintsTest, + HasSameOperandAndResultTypeRequirementSucceeds) { + // Quantizable ops: constants + // Non-quantizable ops: normal StableHLO ops and terminators + constexpr absl::string_view kModuleGather = R"mlir( + module { + func.func @main() -> (tensor<2x3x2x2xf32>) { + %0 = stablehlo.constant dense<1.0> : tensor<3x4x2xf32> + %1 = stablehlo.constant dense<2> : tensor<2x3x2xi64> + %2 = "stablehlo.gather"(%0, %1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> + func.return %2 : tensor<2x3x2x2xf32> + } + } + )mlir"; + OwningOpRef module_op = ParseModuleOpString(kModuleGather); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + Operation* gather_op = FindOperationOfType(main_fn); + const auto spec = GetStableHloQuantConstraints(gather_op); + + EXPECT_THAT(spec, NotNull()); + EXPECT_THAT(spec->has_same_operand_and_result_type_requirement, IsTrue()); +} + } // namespace } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc index 1987b607392379..ad0179f3c051a1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc 
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc @@ -48,9 +48,9 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc index a4bf42ec6f8eba..6577666ab90f10 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc @@ -177,9 +177,6 @@ FailureOr QuantizationMethodToTextProto(const Method& method) { // TODO: b/307620778 - Support more advanced selective quantization methods. LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec, ModuleOp module_op) { - func::FuncOp main_func = FindMainFuncOp(module_op); - if (!main_func) return failure(); - const Method& quantization_method = spec.method(); FailureOr quantization_method_txtpb = @@ -187,14 +184,18 @@ LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec, if (failed(quantization_method_txtpb)) return failure(); const FunctionNameMatcher matcher(spec.matcher().function_name()); - for (auto xla_call_module_op : main_func.getOps()) { - if (!matcher.Match(xla_call_module_op)) continue; - - // Set the text representation of `Method` to matched `TF::XlaCallModuleOp`. - xla_call_module_op->setAttr( - kQuantizationMethodAttr, - StringAttr::get(module_op.getContext(), - std::move(*quantization_method_txtpb))); + // Iterate over all XlaCallModuleOp in all FuncOps. 
+ for (auto func : module_op.getOps()) { + for (auto xla_call_module_op : func.getOps()) { + if (!matcher.Match(xla_call_module_op)) continue; + + // Set the text representation of `Method` to matched + // `TF::XlaCallModuleOp`. + xla_call_module_op->setAttr( + kQuantizationMethodAttr, + StringAttr::get(module_op.getContext(), + std::move(*quantization_method_txtpb))); + } } return success(); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td index 6377740bf6018e..75940a24cf484f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td @@ -40,6 +40,28 @@ def LiftDotGeneralWithBiasSameShape : Pat< (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 5)>; +def LiftConvWithBiasSameShape : Pat< + (StableHLO_AddOp:$res + (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, + $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers, + $feature_group_count, $batch_group_count, $precision_config), + $bias), + (LiftAsTFXlaCallModule<"composite_conv_with_bias_same_shape_fn"> + (ArgumentList $lhs, $rhs, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"window_strides"> (DefaultOrNullAttr $window_strides)), + (NamedAttr<"padding"> (DefaultOrNullAttr $padding)), + (NamedAttr<"lhs_dilation"> (DefaultOrNullAttr $lhs_dilation)), + (NamedAttr<"rhs_dilation"> (DefaultOrNullAttr $rhs_dilation)), + (NamedAttr<"window_reversal"> (DefaultOrNullAttr $window_reversal)), + (NamedAttr<"dimension_numbers"> $dimension_numbers), + (NamedAttr<"feature_group_count"> $feature_group_count), + (NamedAttr<"batch_group_count"> $batch_group_count), 
+ (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 5)>; + + def LiftConvWithBias : Pat< (StableHLO_AddOp:$res (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, @@ -245,6 +267,31 @@ def LiftDotGeneralWithBiasSameShapeAndRelu : Pat< [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>; +def LiftConvWithBiasSameShapeAndRelu : Pat< + (StableHLO_MaxOp:$res + (StableHLO_AddOp + (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, + $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers, + $feature_group_count, $batch_group_count, $precision_config), + $bias), + (StableHLO_ConstantOp $cst)), + (LiftAsTFXlaCallModule<"composite_conv_with_bias_same_shape_and_relu_fn"> + (ArgumentList $lhs, $rhs, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"window_strides"> (DefaultOrNullAttr $window_strides)), + (NamedAttr<"padding"> (DefaultOrNullAttr $padding)), + (NamedAttr<"lhs_dilation"> (DefaultOrNullAttr $lhs_dilation)), + (NamedAttr<"rhs_dilation"> (DefaultOrNullAttr $rhs_dilation)), + (NamedAttr<"window_reversal"> (DefaultOrNullAttr $window_reversal)), + (NamedAttr<"dimension_numbers"> $dimension_numbers), + (NamedAttr<"feature_group_count"> $feature_group_count), + (NamedAttr<"batch_group_count"> $batch_group_count), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + [(IsNotInLiftedFunc $res), + (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>; + + def LiftConvWithBiasAndRelu : Pat< (StableHLO_MaxOp:$res (StableHLO_AddOp diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc new file mode 100644 index 00000000000000..acfe3cfd6fc6b2 --- /dev/null +++ 
b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc @@ -0,0 +1,145 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" + +namespace mlir::quant::stablehlo { + +#define GEN_PASS_DEF_MERGEFUSIONWITHDEQUANTIZEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + 
+namespace { + +class MergeFusionWithDequantizePass + : public impl::MergeFusionWithDequantizePassBase< + MergeFusionWithDequantizePass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeFusionWithDequantizePass) + + explicit MergeFusionWithDequantizePass() = default; + + private: + void runOnOperation() override; +}; + +class MergeFusionWithUniformDequantizePattern + : public OpRewritePattern { + public: + explicit MergeFusionWithUniformDequantizePattern(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(func::CallOp call_op, + PatternRewriter& rewriter) const override { + if (call_op.getNumResults() != 1) return failure(); + auto users = call_op->getUsers(); + for (auto user : users) { + if (!llvm::isa(user)) { + return failure(); + } + } + auto func_name = call_op.getCallee(); + if (!func_name.starts_with("quantized_")) return failure(); + if (call_op->getNumResults() != 1) return failure(); + if (!getElementTypeOrSelf(call_op->getResult(0).getType()) + .isa()) + return failure(); + + // Fetch the callee function. + SymbolTable symbol_table(call_op->getParentOfType()); + auto func_op = + dyn_cast_or_null(symbol_table.lookup(func_name)); + if (!func_op) return failure(); + // The quantized fusion should have requantize and return ops at the end. + auto return_op = dyn_cast_or_null( + func_op.getRegion().getBlocks().front().getTerminator()); + if (!return_op) return failure(); + auto req_op = llvm::dyn_cast_or_null( + return_op.getOperands()[0].getDefiningOp()); + if (!req_op) return failure(); + + // Create a new func.call op with f32 output. + auto new_call_op = call_op.clone(); + new_call_op->getResult(0).setType( + call_op.getResult(0).getType().cast().clone( + rewriter.getF32Type())); + rewriter.setInsertionPoint(call_op); + rewriter.insert(new_call_op); + + // Remove the dequantize ops and replace uses by the new func.call op. 
+ SmallVector users_to_erase; + for (auto user : users) { + llvm::dyn_cast(user) + .replaceAllUsesWith(new_call_op.getResult(0)); + users_to_erase.push_back(user); + } + for (auto user : users_to_erase) rewriter.eraseOp(user); + rewriter.eraseOp(call_op); + func_op.eraseResult(0); + func_op.insertResult(0, new_call_op.getResult(0).getType(), + /*resultAttrs=*/nullptr); + + // Modify the quantized fused function to do dequantize+relu(6). + rewriter.setInsertionPoint(req_op); + Value new_result = rewriter.create( + req_op.getLoc(), func_op.getResultTypes()[0], req_op.getOperand()); + if (func_name.contains("_relu6_")) { + auto min = rewriter.create( + req_op.getLoc(), rewriter.getF32FloatAttr(0)); + auto max = rewriter.create( + req_op.getLoc(), rewriter.getF32FloatAttr(6)); + new_result = rewriter.create( + req_op.getLoc(), min, new_result, max); + } else if (func_name.contains("_relu_")) { + auto min = rewriter.create( + req_op.getLoc(), rewriter.getF32FloatAttr(0)); + new_result = rewriter.create( + req_op.getLoc(), min, new_result, nullptr); + } + return_op->setOperand(0, new_result); + rewriter.eraseOp(req_op); + + return success(); + } +}; + +void MergeFusionWithDequantizePass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index 63f6f822dbebdf..fdb7fa7941f025 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -60,10 +60,6 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function "enable-per-channel-quantized-weight", 
"bool", /*default=*/"true", "Whether to enable per-channel quantized weights.">, - Option<"enable_full_int_quantization_", - "enable-full-int-quantization", - "bool", /*default=*/"false", - "Whether to enable full int quantization, including non compute-heavy ops.">, Option<"mlir_dump_file_name_", "mlir-dump-file-name", "std::optional", /*default=*/"std::nullopt", "MLIR dump file name.">, @@ -71,6 +67,10 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function "enable-weight-only", "bool", /*default=*/"false", "Whether to produce weight-only quantized op for convolution and dot_general op.">, + Option<"merge_fusion_with_dequantize_", + "merge-fusion-with-dequantize", + "bool", /*default=*/"false", + "Whether to merge quantized conv/dot_general fusion with subsequent dequantize.">, ]; let dependentDialects = [ "mlir::arith::ArithDialect", @@ -106,10 +106,6 @@ def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { "enable-per-channel-quantized-weight", "bool", /*default=*/"true", "Whether to enable per-channel quantized weights.">, - Option<"enable_full_int_quantization_", - "enable-full-int-quantization", - "bool", /*default=*/"false", - "Whether to apply full int quantization, including non compute-heavy ops.">, Option<"enable_weight_only_", "enable-weight-only", "bool", /*default=*/"false", @@ -130,6 +126,21 @@ def PostQuantizePass : Pass<"stablehlo-post-quantize", "mlir::func::FuncOp"> { ]; } +def XlaCallModuleToCallPass : Pass<"stablehlo-xla-call-module-to-call", "ModuleOp"> { + let summary = "Convert XlaCallModuleOp to func.call op"; + let dependentDialects = [ + "TF::TensorFlowDialect", + ]; +} + +def MergeFusionWithDequantizePass : Pass<"stablehlo-merge-fusion-with-dequantize", "mlir::ModuleOp"> { + let summary = "Merge quantized conv/dot_general fusion with subsequent dequantize."; + let dependentDialects = [ + "chlo::ChloDialect", + "mlir::stablehlo::StablehloDialect", + ]; +} + def UnwrapXlaCallModuleOpPass : 
Pass<"stablehlo-unwrap-xla-call-module-op", "ModuleOp"> { let summary = "Unwrap XlaCallModuleOps into inline functions if not used for quantizing fused patterns."; let dependentDialects = ["TF::TensorFlowDialect"]; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc index 05d5d71d4d3c17..7d2df9e27f9220 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc @@ -140,7 +140,7 @@ void PrepareQuantizePass::runOnOperation() { MLIRContext* ctx = module_op.getContext(); auto func_op_quant_spec = GetStableHloOpQuantSpec; - auto func_op_quant_scale_spec = GetStableHloQuantScaleSpec; + auto func_op_quant_scale_spec = GetStableHloQuantConstraints; for (auto func_op : module_op.getOps()) { // The function might contain more stats ops than required, and it will diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 10b15f1132fe62..a6d041a5b8cb9e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -78,8 +78,6 @@ using ::stablehlo::quantization::Method; using ::stablehlo::quantization::QuantizedType; using ::stablehlo::quantization::StaticRangePtq; -constexpr StringRef kCompositeFuncPrefix = "composite_"; -constexpr StringRef kQuantizedFuncPrefix = "quantized_"; constexpr StringRef kEntryFuncAttrName = "_entry_function"; // Returns broadcasted user op of an input op. 
Returns null if @@ -515,9 +513,35 @@ class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern { void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override { auto singular_op = *entry_func_op.getOps().begin(); - Value singular_op_result = singular_op.getResult(); - singular_op_result.setType(entry_func_op.getResultTypes()[0]); + + // For ops that require same operand and result types, use explicit + // requantize op rather than using `entry_func_op`'s result as op result. + auto spec = GetStableHloQuantConstraints(singular_op); + const bool has_same_operand_and_result_type = + spec->has_same_operand_and_result_type_requirement; + if (has_same_operand_and_result_type) { + const Type operand_type = entry_func_op.getArgumentTypes()[0]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + // Get the quantized tensor manipulation op's output type and update. + const auto singular_op_result_type = + singular_op_result.getType().cast(); + const ArrayRef singular_op_shape = + singular_op_result_type.getShape(); + const TensorType new_singular_op_result_type = + singular_op_result_type.cloneWith( + singular_op_shape, + getElementTypeOrSelf(operand_type).cast()); + singular_op_result.setType(new_singular_op_result_type); + + // Create requantization op and return. + rewriter.setInsertionPointAfter(singular_op); + CreateAndReturnUniformQuantizeOp(rewriter, *singular_op, entry_func_op, + func_result_type); + } else { + singular_op_result.setType(entry_func_op.getResultTypes()[0]); + } } }; @@ -543,6 +567,29 @@ void QuantizeEntryFuncOp( entry_func_op.setSymName(quantized_function_name); } +// Replaces `xla_call_module_op` with a newly created `func::CallOp`, where the +// callee is `callee_func_op`. The existence of `kQuantizationMethodAttr` in +// `xla_call_module_op` should be guaranteed. 
+void ReplaceXlaCallModuleOpWithNewCallOp(TF::XlaCallModuleOp xla_call_module_op, + func::FuncOp callee_func_op, + PatternRewriter& rewriter) { + OpBuilder::InsertionGuard insertion_guard(rewriter); + + // Create a new `CallOp` that calls `callee_func_op`. + rewriter.setInsertionPoint(xla_call_module_op); + auto call_op = + rewriter.create(xla_call_module_op.getLoc(), callee_func_op, + xla_call_module_op.getArgs()); + + // Transfer the `kQuantizationMethodAttr` attribute to the `CallOp`, + // indicating what `Method` has been applied to the quantized unit. + call_op->setAttr( + kQuantizationMethodAttr, + xla_call_module_op->getAttrOfType(kQuantizationMethodAttr)); + + rewriter.replaceOp(xla_call_module_op, call_op); +} + // Replaces a quantized `xla_call_module_op` with a `func::CallOp`. The callee // is expected to remain unquantized (thus having a signature mismatch), and it // is also quantized accordingly. @@ -558,10 +605,8 @@ void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, body_rewrite_pattern, quantization_method); - // Replace the XlaCallModuleOp with a new CallOp. - rewriter.setInsertionPoint(xla_call_module_op); - rewriter.replaceOpWithNewOp(xla_call_module_op, entry_func_op, - xla_call_module_op.getArgs()); + ReplaceXlaCallModuleOpWithNewCallOp(xla_call_module_op, entry_func_op, + rewriter); } // Pattern that mainly does two things: @@ -593,6 +638,10 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { ModuleOp module_op = op->getParentOfType(); SymbolTable symbol_table(module_op); + // Ignore ops without quantization method. + // Consider adding checks for individual methods. + if (!op->getAttr(kQuantizationMethodAttr)) return failure(); + // Ignore unquantized ops. 
if (!IsQuantizedXlaCallModuleOp(op)) return failure(); @@ -664,7 +713,7 @@ class QuantizeOpWithRegionPattern // Quantization parameters can be propagated only for same-scale ops and // same-scale ops are quantized only when they are connected to quantized // composite functions. - if (!GetStableHloQuantScaleSpec(op_with_region) + if (!GetStableHloQuantConstraints(op_with_region) ->has_same_scale_requirement || !IsConnectedWithQuantizedCompsiteFunction(op_with_region)) { return failure(); @@ -866,7 +915,8 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { } // Check whether the preceding op is a quantized same-scale op. - if (GetStableHloQuantScaleSpec(preceding_op)->has_same_scale_requirement) { + if (GetStableHloQuantConstraints(preceding_op) + ->has_same_scale_requirement) { for (const OpResult result : preceding_op->getResults()) { const Type element_type = getElementTypeOrSelf(result.getType()); if (element_type.isa()) { @@ -893,7 +943,7 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { } // Check whether the following op is a quantized same-scale op. - if (GetStableHloQuantScaleSpec(following_op) + if (GetStableHloQuantConstraints(following_op) ->has_same_scale_requirement) { for (Value operand : following_op->getOperands()) { const Type element_type = getElementTypeOrSelf(operand.getType()); @@ -923,7 +973,9 @@ class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern { }; // Compute heavy patterns should be quantized for both server and ODML targets. -void PopulateComputeHeavyPatterns( +// Most patterns here are useful when quantized since they are compute heavy +// or memory bound. 
+void PopulateCommonQuantizationPatterns( MLIRContext& ctx, RewritePatternSet& patterns, const bool enable_per_channel_quantized_weight) { patterns.add>( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h index 9aa33ee0316ee1..67eb267c1d9037 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -148,7 +148,7 @@ class StableHloQuantizationPattern : public OpRewritePattern { return failure(); } - if (GetStableHloQuantScaleSpec(candidate_op) + if (GetStableHloQuantConstraints(candidate_op) ->has_same_scale_requirement && !IsConnectedWithQuantizedCompsiteFunction(candidate_op)) { return failure(); @@ -250,9 +250,10 @@ class StableHloQuantizationPattern : public OpRewritePattern { } }; -// Populates pattern for compute heavy operations. -void PopulateComputeHeavyPatterns(MLIRContext& ctx, RewritePatternSet& patterns, - bool enable_per_channel_quantized_weight); +// Populates common patterns that are usually compute heavy or memory bound. +void PopulateCommonQuantizationPatterns( + MLIRContext& ctx, RewritePatternSet& patterns, + bool enable_per_channel_quantized_weight); // Populates conversion patterns for all quantizable ops, including // ops that are not compute-heavy and data movement ops. 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 8bb2bd33564481..0000057402886f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -103,10 +103,8 @@ class QuantizePass : public impl::QuantizePassBase { using impl::QuantizePassBase::QuantizePassBase; explicit QuantizePass(const bool enable_per_channel_quantized_weight, - const bool enable_full_int_quantization, const bool enable_weight_only) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; - enable_full_int_quantization_ = enable_full_int_quantization; enable_weight_only_ = enable_weight_only; } @@ -125,13 +123,11 @@ void QuantizePass::runOnOperation() { PopulateQuantizeWeightOnlyPatterns(ctx, patterns); } - PopulateComputeHeavyPatterns(ctx, patterns, - enable_per_channel_quantized_weight_); + PopulateCommonQuantizationPatterns(ctx, patterns, + enable_per_channel_quantized_weight_); // Quantize all quantizable ops, including ops that are not compute-heavy. - if (enable_full_int_quantization_) { - PopulateAllQuantizablePatterns(ctx, patterns); - } + PopulateAllQuantizablePatterns(ctx, patterns); if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { // There are cases where no rewrites happen even if a pattern matches, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index f3cf92dde359d1..1efc5d40c7ce20 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -26,6 +27,7 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" @@ -41,7 +43,6 @@ namespace mlir::quant::stablehlo { namespace { -using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; using ::tensorflow::quantization::RunPassesOnModuleOp; class QuantizeCompositeFunctionsPass @@ -55,9 +56,8 @@ class QuantizeCompositeFunctionsPass explicit QuantizeCompositeFunctionsPass( const bool enable_per_channel_quantized_weight, - const bool enable_weight_only, const bool enable_full_int_quantization) { + const bool enable_weight_only) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; - enable_full_int_quantization_ = enable_full_int_quantization; enable_weight_only_ = enable_weight_only; } @@ -90,21 +90,34 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { QuantizePassOptions quantize_options; quantize_options.enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight_; - quantize_options.enable_full_int_quantization_ = - enable_full_int_quantization_; quantize_options.enable_weight_only_ = enable_weight_only_; // QuantizePass modifies FuncOps referenced 
outside of its given scope // and therefore requires a module-level context. pm.addPass(createQuantizePass(quantize_options)); pm.addNestedPass(createPostQuantizePass()); + // Convert XlaCallModuleOps lifted but not quantized to func.call op. + // The reasons these ops are not quantized may be: + // 1. Disabled due to selective quantization. + // 2. Not supported, e.g. add op for server. + pm.addPass(createXlaCallModuleToCallPass()); + + // TODO: b/321729008 - move this implementation to quantization_patterns.cc. + if (merge_fusion_with_dequantize_) { + pm.addPass(createMergeFusionWithDequantizePass()); + } + ModuleOp module_op = getOperation(); if (const absl::Status pm_run_status = RunPassesOnModuleOp(mlir_dump_file_name_, pm, module_op); !pm_run_status.ok()) { signalPassFailure(); } + + // Emit human-readable quantization report. + const QuantizationReport report(module_op); + report.Print(); } -} // namespace +} // namespace } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc index 5209f6be325979..6ed82c125b0be9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -488,7 +489,7 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass:: func::FuncOp main_func = FindMainFuncOp(module_op); if (!main_func) return; - // To handle the case where `main` function has tf.StatefulPartitionedCallOp, + // In case the model has tf.StatefulPartitionedCallOp or tf.PartitionedCallOp, // we recursively find called functions and process StableHLO ops in them. SmallVector func_ops; func_ops.push_back(main_func); @@ -499,6 +500,10 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass:: if (!main_func) continue; SymbolTable symbol_table(module_op); + for (auto call_op : main_func.getOps()) { + func_ops.push_back(dyn_cast_or_null(symbol_table.lookup( + call_op.getFAttr().cast().getValue()))); + } for (auto call_op : main_func.getOps()) { func_ops.push_back( dyn_cast_or_null(symbol_table.lookup(call_op.getF()))); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc index d596d1885c8066..bdf7f311f26bfa 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc @@ -24,12 +24,12 @@ limitations under the License. 
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep namespace mlir::quant::stablehlo::testing { @@ -39,8 +39,9 @@ namespace mlir::quant::stablehlo::testing { namespace { +using ::stablehlo::quantization::ExpandPresets; using ::stablehlo::quantization::PipelineConfig; -using ::stablehlo::quantization::StaticRangePtqPreset; +using ::stablehlo::quantization::QuantizationConfig; class TestPostCalibrationComponentPass : public impl::TestPostCalibrationComponentPassBase< @@ -61,12 +62,16 @@ void TestPostCalibrationComponentPass::runOnOperation() { OpPassManager pm(ModuleOp::getOperationName()); - StaticRangePtqPreset static_range_ptq_preset; + QuantizationConfig config = QuantizationConfig::default_instance(); + config.mutable_static_range_ptq_preset(); + + const QuantizationConfig new_config = ExpandPresets(config); + PipelineConfig pipeline_config; pipeline_config.set_unpack_quantized_types(unpack_quantized_types_); PostCalibrationComponent component(&ctx); - component.AddPasses(pm, static_range_ptq_preset, pipeline_config); + component.AddPasses(pm, new_config.specs(), pipeline_config); if (failed(runPipeline(pm, module_op))) { 
signalPassFailure(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc new file mode 100644 index 00000000000000..123244db3b7dbb --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::quant::stablehlo { + +#define GEN_PASS_DEF_XLACALLMODULETOCALLPASS +#include 
"tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +namespace { + +// Converts XlaCallModuleOps to func.call. +class XlaCallModuleToCallPass + : public impl::XlaCallModuleToCallPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaCallModuleToCallPass) + + explicit XlaCallModuleToCallPass() = default; + + private: + void runOnOperation() override; +}; + +// Converts XlaCallModuleOps to func.call. +class XlaCallModuleOpToCallOp : public OpRewritePattern { + public: + explicit XlaCallModuleOpToCallOp(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter& rewriter) const override { + auto module_op = op->getParentOfType(); + SymbolTable symbol_table(module_op); + + auto entry_func_op = dyn_cast_or_null( + symbol_table.lookup(GetEntryFunctionName(op))); + if (!entry_func_op) return failure(); + + // Replace the XlaCallModuleOp with a new CallOp. + rewriter.replaceOpWithNewOp(op, entry_func_op, op.getArgs()); + return success(); + } +}; + +void XlaCallModuleToCallPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + RewritePatternSet patterns(&getContext()); + patterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index 2b20cc48a89d69..df5252b986adf5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -60,8 +60,10 @@ pytype_strict_library( # "//tensorflow/python/ops:array_ops", # "//tensorflow/python/ops:math_ops", # "//tensorflow/python/ops:nn_ops", +# "//tensorflow/python/ops:variables", # "//tensorflow/python/platform:client_testlib", # 
"//tensorflow/python/saved_model:load", +# "//tensorflow/python/saved_model:loader", # "//tensorflow/python/saved_model:save", # "//tensorflow/python/types:core", # "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 80a2c560ef865b..f65c56bc577742 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -64,6 +64,7 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): ([10, 1, 1024], [10, 1024, 3]), ([2, 3, 1, 1024], [2, 3, 1024, 3]), ), + 'merge_fusion_with_dequantize': (False, True), }]) ) @test_util.run_in_graph_and_eager_modes @@ -72,6 +73,7 @@ def test_matmul_ptq_model( bias_fn: Optional[ops.Operation], activation_fn: Optional[ops.Operation], dim_sizes: Sequence[int], + merge_fusion_with_dequantize: bool, ): lhs_dim_size, rhs_dim_size = dim_sizes input_shape = (*lhs_dim_size,) @@ -115,6 +117,9 @@ def data_gen() -> repr_dataset.RepresentativeDataset: ] ), tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + pipeline_config=qc.PipelineConfig( + merge_fusion_with_dequantize=merge_fusion_with_dequantize + ), ) quantization.quantize_saved_model( self._input_saved_model_path, @@ -150,6 +155,19 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 0.65, ) + if merge_fusion_with_dequantize: + # Check activation functions are explicitly present. + # If present the last op before return should be stablehlo.clamp for relu6 + # and stablehlo.maximum for relu. 
+ if activation_fn is nn_ops.relu6: + self.assertRegex(module_str, r'stablehlo.clamp.*\n.*return') + elif activation_fn is nn_ops.relu: + self.assertRegex(module_str, r'stablehlo.maximum.*\n.*return') + else: + # Check activation functions are implicit. + self.assertNotRegex(module_str, r'stablehlo.clamp.*\n.*return') + self.assertNotRegex(module_str, r'stablehlo.maximum.*\n.*return') + @parameterized.parameters( testing.parameter_combinations([{ 'same_scale_op': ( @@ -342,6 +360,8 @@ def data_gen() -> repr_dataset.RepresentativeDataset: False, True, ), + 'merge_fusion_with_dequantize': (False, True), + 'has_func_alias': (False, True), }]) ) @test_util.run_in_graph_and_eager_modes @@ -352,7 +372,9 @@ def test_conv_ptq_model( has_batch_norm: bool, input_shape_dynamic: bool, enable_per_channel_quantized_weight: bool, + merge_fusion_with_dequantize: bool, dilations: Sequence[int] = None, + has_func_alias: bool = False, ): input_shape = (None, 3, 4, 3) if input_shape_dynamic else (1, 3, 4, 3) filter_shape = (2, 3, 3, 2) @@ -366,15 +388,16 @@ def test_conv_ptq_model( has_batch_norm, strides, dilations, + 'SAME', + has_func_alias, ) - # TODO(b/331809306): investigate why these tests fail. - # skip these test cases. - if ( - bias_fn is None - and has_batch_norm - and input_shape_dynamic - and enable_per_channel_quantized_weight - ): + # TODO: b/331809306 - Investigate why these test fail then re-enable. + if has_batch_norm and (bias_fn or not input_shape_dynamic): + return + + # TODO: b/331120943 - Re-enable this after correctly handling quantization + # granularity per quantizable scope. + if has_batch_norm and (not bias_fn and input_shape_dynamic): return # Generate model input data. 
@@ -410,6 +433,9 @@ def data_gen() -> repr_dataset.RepresentativeDataset: enable_per_channel_quantized_weight=enable_per_channel_quantized_weight, ), tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + pipeline_config=qc.PipelineConfig( + merge_fusion_with_dequantize=merge_fusion_with_dequantize + ), ) quantization.quantize_saved_model( self._input_saved_model_path, @@ -445,6 +471,27 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 0.61, ) + if merge_fusion_with_dequantize: + # Check activation functions are explicitly present. + # If present the last op before return should be stablehlo.clamp for relu6 + # and stablehlo.maximum for relu. + if activation_fn is nn_ops.relu6: + self.assertRegex(module_str, r'stablehlo.clamp.*\n.*return') + elif activation_fn is nn_ops.relu: + self.assertRegex(module_str, r'stablehlo.maximum.*\n.*return') + else: + # Check activation functions are implicit. + self.assertNotRegex(module_str, r'stablehlo.clamp.*\n.*return') + self.assertNotRegex(module_str, r'stablehlo.maximum.*\n.*return') + + if has_func_alias: + func_aliases = self._get_function_aliases( + self._output_saved_model_path, [tag_constants.SERVING] + ) + self.assertCountEqual( + func_aliases.values(), [quantize_model_test_base.FUNC_ALIAS] + ) + @parameterized.parameters( testing.parameter_combinations([{ 'equation': ( @@ -528,6 +575,66 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 0.65, ) + @parameterized.named_parameters( + ('use_constant_with_int32_input', np.int32, False), + ('use_variable_with_int32_input', np.int32, True), + ('use_constant_with_int64_input', np.int64, False), + ('use_variable_with_int64_input', np.int64, True), + ) + @test_util.run_v2_only + def test_gather_model(self, input_type, use_variable): + model = self._create_gather_model(input_type, use_variable) + + save.save(model, self._input_saved_model_path) + + rng = np.random.default_rng(seed=42) + static_input_shape = [6] + + def data_gen() -> 
repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=10, size=static_input_shape + ).astype(input_type) + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ) + ] + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + module_str = self._extract_first_xla_call_module_op( + self._output_saved_model_path + ) + self.assertTrue(re.search('stablehlo.gather.*xi8>', module_str)) + + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 1 / 3, + ) + def test_when_preset_not_srq_raises_error(self): self._create_matmul_model( input_shape=(1, 1024), @@ -985,7 +1092,7 @@ def test_matmul_weight_only_model( ) config = qc.QuantizationConfig( - weight_only_preset=qc.WeightOnlyPreset(), + weight_only_ptq_preset=qc.WeightOnlyPtqPreset(), tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), ) quantization.quantize_saved_model( @@ -1010,9 +1117,8 @@ def test_matmul_weight_only_model( self._output_saved_model_path ) - # Tests that the output graph contains subtract and multiply for + # Tests that the output graph contains multiply for symmetric # dequantization. 
- self.assertTrue(re.search('stablehlo.subtract', module_str)) self.assertTrue(re.search('stablehlo.multiply', module_str)) # Tests that the output graph contains float dot_general. self.assertTrue( @@ -1043,6 +1149,7 @@ def test_matmul_weight_only_model( False, True, ), + 'has_func_alias': (False, True), }]) ) @test_util.run_in_graph_and_eager_modes @@ -1053,6 +1160,7 @@ def test_conv_weight_only_model( has_batch_norm: bool, input_shape_dynamic: bool, dilations: Sequence[int] = None, + has_func_alias: bool = False, ): input_shape = (None, 3, 4, 3) if input_shape_dynamic else (1, 3, 4, 3) filter_shape = (2, 3, 3, 2) @@ -1066,6 +1174,8 @@ def test_conv_weight_only_model( has_batch_norm, strides, dilations, + 'SAME', + has_func_alias, ) rng = np.random.default_rng(1234) @@ -1077,7 +1187,7 @@ def test_conv_weight_only_model( ) config = qc.QuantizationConfig( - weight_only_preset=qc.WeightOnlyPreset(), + weight_only_ptq_preset=qc.WeightOnlyPtqPreset(), tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), ) quantization.quantize_saved_model( @@ -1111,14 +1221,142 @@ def test_conv_weight_only_model( re.search('stablehlo.convolution.*xf32>.*xf32>.*xf32>', module_str) ) + if has_func_alias: + func_aliases = self._get_function_aliases( + self._output_saved_model_path, [tag_constants.SERVING] + ) + self.assertCountEqual( + func_aliases.values(), [quantize_model_test_base.FUNC_ALIAS] + ) + # Due to other meta data, the compression is not exactly 1/4. self.assertLess( testing.get_size_ratio( self._output_saved_model_path, self._input_saved_model_path ), - 0.35, + 0.4, + ) + + @parameterized.parameters( + testing.parameter_combinations([{ + 'shape_dynamic': ( + False, + True, + ), + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_add_ptq_model( + self, + shape_dynamic: bool, + ): + input_shape = (None, 3, 4, 3) if shape_dynamic else (2, 3, 4, 3) + self._create_add_model( + input_shape, + self._input_saved_model_path, ) + # Generate model input data. 
+ rng = np.random.default_rng(seed=42) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=static_input_shape + ).astype(np.float32) + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ) + ], + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + self.assertEqual( + self._get_num_xla_call_module_op(self._output_saved_model_path), 1 + ) + module_str = self._extract_first_xla_call_module_op( + self._output_saved_model_path + ) + + # Check add is not quantized. + self.assertTrue(re.search(r'stablehlo.add.*f32>', module_str)) + + @parameterized.parameters( + testing.parameter_combinations([{ + 'shape_dynamic': ( + False, + True, + ), + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_add_weight_only_model( + self, + shape_dynamic: bool, + ): + input_shape = (None, 3, 4, 3) if shape_dynamic else (2, 3, 4, 3) + self._create_add_model( + input_shape, + self._input_saved_model_path, + ) + + # Generate model input data. 
+ rng = np.random.default_rng(seed=42) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=static_input_shape + ).astype(np.float32) + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + weight_only_ptq_preset=qc.WeightOnlyPtqPreset(), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + self.assertEqual( + self._get_num_xla_call_module_op(self._output_saved_model_path), 1 + ) + module_str = self._extract_first_xla_call_module_op( + self._output_saved_model_path + ) + + # Check add is not quantized. 
+ self.assertTrue(re.search(r'stablehlo.add.*f32>', module_str), module_str) + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py index d71c89e15d313f..31c53a4cf20fe9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py @@ -31,11 +31,15 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.saved_model import load +from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import save as saved_model_save from tensorflow.python.types import core +FUNC_ALIAS = 'some_alias' + class QuantizedModelTest(test.TestCase, parameterized.TestCase): """Base test class for StableHLO quant tests.""" @@ -72,6 +76,29 @@ def _extract_first_xla_call_module_op( return str(stablehlo_module) raise ValueError('No XlaCallModule found in saved model.') + def _get_num_xla_call_module_op(self, output_saved_model_path: str) -> int: + """Gets the number of XlaCallModule ops in the output saved model.""" + root = load.load(output_saved_model_path) + tf_graph_def = root.signatures['serving_default'].graph.as_graph_def() + count = 0 + for node_def in tf_graph_def.node: + if node_def.op == 'XlaCallModule': + count += 1 + for function in tf_graph_def.library.function: + for node_def in function.node_def: + if node_def.op == 'XlaCallModule': + count += 1 + return count + + def _get_function_aliases( + self, output_saved_model_path: str, tags: List[str] + ) -> dict[str, str]: + """Gets the function aliases in the output 
saved model.""" + loader = loader_impl.SavedModelLoader(output_saved_model_path) + return loader.get_meta_graph_def_from_tags( + tags + ).meta_info_def.function_aliases + def _create_matmul_model( self, input_shape: Sequence[int], @@ -238,6 +265,7 @@ def _create_conv2d_model( strides: Sequence[int] = (1, 1, 1, 1), dilations: Sequence[int] = (1, 1, 1, 1), padding: str = 'SAME', + has_func_alias: bool = False, ) -> module.Module: class ConvModel(module.Module): """A simple model with a single conv2d, bias and relu.""" @@ -294,6 +322,11 @@ def conv2d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: return {'output': out} model = ConvModel() + save_options = None + if has_func_alias: + save_options = tensorflow.saved_model.SaveOptions( + function_aliases={FUNC_ALIAS: model.conv2d} + ) saved_model_save.save( model, saved_model_path, @@ -302,6 +335,76 @@ def conv2d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: shape=input_shape, dtype=dtypes.float32, name='input_tensor' ) ), + options=save_options, + ) + return model + + def _create_gather_model(self, input_type, use_variable) -> module.Module: + class GatherModel(module.Module): + """A simple model with a single gather.""" + + def __init__(self, use_variable): + """Initializes a GatherModel. + + Args: + use_variable: If True, creates a variable for weight. 
+ """ + super().__init__() + w_val = np.random.randn(128, 32).astype('f4') + if use_variable: + self.w = variables.Variable(w_val) + else: + self.w = w_val + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + shape=[6], dtype=input_type, name='input_tensor' + ) + ] + ) + def __call__( + self, input_tensor: core.Tensor + ) -> Mapping[str, core.Tensor]: + """Performs a gather operation.""" + out = array_ops.gather_v2(self.w, input_tensor) + return {'output': out} + + return GatherModel(use_variable) + + def _create_add_model( + self, + shape: Sequence[int], + saved_model_path: str, + ) -> module.Module: + class AddModel(module.Module): + """A simple model with a single add.""" + + def __init__(self): + pass + + @def_function.function + def add(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: + """Performs an add operation. + + Args: + input_tensor: Input tensor to perform add on. + + Returns: + A map of: output key -> output result. + """ + out = math_ops.add(input_tensor, input_tensor) + return {'output': out} + + model = AddModel() + saved_model_save.save( + model, + saved_model_path, + signatures=model.add.get_concrete_function( + tensor_spec.TensorSpec( + shape=shape, dtype=dtypes.float32, name='input_tensor' + ) + ), ) return model diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/stablehlo_quantizer_odml_oss.ipynb b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/stablehlo_quantizer_odml_oss.ipynb new file mode 100644 index 00000000000000..858401154bec3b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/stablehlo_quantizer_odml_oss.ipynb @@ -0,0 +1,192 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "tWhm0JFMPJ5I" + }, + "source": [ + "Copyright 2024 Google LLC.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");" + ] + }, + { + "cell_type": "markdown", + "metadata": { + 
"id": "RJcqTAlfPQjk" + }, + "source": [ + "# [OSS] JAX to TFLite with StableHLO Quantization Demonstration for ODML." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cqeGmbO6PPNd" + }, + "source": [ + "This example shows a JAX Keras reference model converted into a StableHLO module and via `jax2tf`, then quantized in the ODML Converter via the StableHLO Quantizer.\n", + "\n", + "Note: This API is experimental and will likely have breakages with other models. Please reach out to [scalable-opt-team@google.com](mailto:scalable-opt-team@google.com) and we will support your use case." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-S0P42BpPSeJ" + }, + "source": [ + "## StableHLO Quantizer\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FacwMD9MPUew" + }, + "source": [ + "StableHLO Quantizer is a quantization API to enable ML framework optionality and hardware retargetability." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RXZUHZQoQZOo" + }, + "outputs": [], + "source": [ + "!pip uninstall tensorflow --yes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aYz36YEKPYRk" + }, + "outputs": [], + "source": [ + "!pip3 install tf-nightly\n", + "!pip3 install keras-core" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "duab6P-nPZzF" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "print(\"TensorFlow version:\", tf.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c9JX9RJTPaoW" + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ['KERAS_BACKEND'] = 'jax'\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from keras_core.applications import ResNet50\n", + "from jax.experimental import jax2tf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + 
"id": "rTcHwDPBPchd" + }, + "outputs": [], + "source": [ + "input_shape = (1, 224, 224, 3)\n", + "\n", + "jax_callable = jax2tf.convert(\n", + " ResNet50(\n", + " input_shape=input_shape[1:],\n", + " pooling='avg',\n", + " ).call,\n", + " with_gradient=False,\n", + " native_serialization=True,\n", + " native_serialization_platforms=('cpu',))\n", + "\n", + "tf_module = tf.Module()\n", + "tf_module.f = tf.function(\n", + " jax_callable,\n", + " autograph=False,\n", + " input_signature=[\n", + " tf.TensorSpec(input_shape, jnp.float32, 'lhs_operand')\n", + " ],\n", + ")\n", + "\n", + "saved_model_dir = '/tmp/saved_model'\n", + "tf.saved_model.save(tf_module, saved_model_dir)\n", + "\n", + "def calibration_dataset():\n", + " rng = np.random.default_rng(seed=1235)\n", + " for _ in range(2):\n", + " yield {\n", + " 'lhs_operand': rng.uniform(low=-1.0, high=1.0, size=input_shape).astype(\n", + " np.float32\n", + " )\n", + " }\n", + "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n", + "converter.target_spec.supported_ops = [\n", + " tf.lite.OpsSet.SELECT_TF_OPS, # enable TensorFlow ops.\n", + " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TFL ops.\n", + "]\n", + "converter.representative_dataset = calibration_dataset\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", + "# Below flag controls whether to use StableHLO Quantizer or TFLite quantizer.\n", + "converter.experimental_use_stablehlo_quantizer = True\n", + "\n", + "quantized_model = converter.convert()\n", + "\n", + "with open('/tmp/resnet50_quantized.tflite', 'wb') as f:\n", + " f.write(quantized_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "u3b9Xj8dPdXo" + }, + "outputs": [], + "source": [ + "print(str(os.path.getsize('/tmp/resnet50_quantized.tflite') \u003e\u003e 20) + 'MB')" + ] + } + ], + "metadata": { + "colab": { + "private_outputs": true, + "provenance": [ + { + "file_id": 
"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/stablehlo_quantizer_odml_oss.ipynb", + "timestamp": 1712841250910 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py index aa3745a3fdd453..f9a1a90e071453 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py @@ -43,6 +43,16 @@ def _serialize_signature_def_map( return signature_def_map_serialized +def _has_quantization_method( + quantization_specs: qc.QuantizationSpecs, method: str +) -> bool: + """Returns whether a given QuantizationSpecs has the given quantization method.""" + for spec in quantization_specs.specs: + if spec.method.HasField(method): + return True + return False + + # TODO: b/310594193 - Export API to pip package. def quantize_saved_model( src_saved_model_path: str, @@ -60,15 +70,6 @@ def quantize_saved_model( ValueError: When `config` was not configured for static-range PTQ single representative dataset. """ - if not ( - config.HasField('static_range_ptq_preset') - and len(config.static_range_ptq_preset.representative_datasets) == 1 - ) and not config.HasField('weight_only_preset'): - raise ValueError( - '`quantize_saved_model` currently only supports static-range PTQ with a' - ' single signature or weight-only quantization.' - ) - # Updates user-provided `QuantizationConfig`s for the internal quantization # pipeline to work with. 
print('=== User-provided QuantizationConfig ===') @@ -82,6 +83,15 @@ def quantize_saved_model( print('=== Updated QuantizationConfig ===') print(config) + if not ( + _has_quantization_method(config.specs, 'static_range_ptq') + and len(config.calibration_options.representative_datasets) == 1 + ) and not _has_quantization_method(config.specs, 'weight_only_ptq'): + raise ValueError( + '`quantize_saved_model` currently only supports static-range PTQ with a' + ' single signature or weight-only quantization.' + ) + signature_def_map = save_model.get_signatures_from_saved_model( src_saved_model_path, signature_keys=None, @@ -89,7 +99,9 @@ def quantize_saved_model( ) signature_def_map_serialized = _serialize_signature_def_map(signature_def_map) - if config.HasField('static_range_ptq_preset'): + # Currently, only StaticRangePtq or WeightOnlyPtq is supported. + # Consider merging the pipelines to address mixed algorithm models. + if _has_quantization_method(config.specs, 'static_range_ptq'): pywrap_quantization.static_range_ptq( src_saved_model_path, dst_saved_model_path, @@ -98,7 +110,7 @@ def quantize_saved_model( signature_def_map_serialized=signature_def_map_serialized, py_function_library=py_function_lib.PyFunctionLibrary(), ) - elif config.HasField('weight_only_preset'): + elif _has_quantization_method(config.specs, 'weight_only_ptq'): pywrap_quantization.weight_only_ptq( src_saved_model_path, dst_saved_model_path, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto index efdceebd6c2008..81f2ff3686fbbe 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -77,8 +77,9 @@ message StaticRangePtqPreset { bool enable_full_int_quantization = 3; } -// Applies int8 per-tensor weight-only quantization for all dot_general op. 
-message WeightOnlyPreset {} +// Applies int8 per-tensor weight-only post-training quantization for all +// dot_general op. +message WeightOnlyPtqPreset {} // Metadata specific to the input TensorFlow SavedModel, which may be required // to identify the specific MetaGraphDef to quantize, for example. @@ -96,6 +97,12 @@ message PipelineConfig { // hardware performs better with integer ops. // Default value: true optional bool unpack_quantized_types = 1; + + // When set to True, requantize op in the quantized fusion will merge with the + // subsequent dequantize op if present. + // Default value: false + // TODO: b/321729008 - re-consider default value after testing on prod model. + bool merge_fusion_with_dequantize = 2; } // Represents a single quantizable unit, a (nearly) minimum unit of work when @@ -158,6 +165,12 @@ message StaticRangePtq { map input_quantized_types = 1; } +message WeightOnlyPtq { + // Operand index -> QuantizedType mapping. Operands that are not specified + // here will be quantized with best effort. + map input_quantized_types = 1; +} + // Represents a matching method that matches quantizable units by lifted // functions' names. message FunctionNameMatcherSpec { @@ -178,6 +191,7 @@ message Method { oneof method { NoQuantization no_quantization = 1; StaticRangePtq static_range_ptq = 2; + WeightOnlyPtq weight_only_ptq = 3; } } @@ -322,7 +336,7 @@ message QuantizationConfig { oneof preset { // Performs best-effort static-range post-training quantization (PTQ). StaticRangePtqPreset static_range_ptq_preset = 1; - WeightOnlyPreset weight_only_preset = 7; + WeightOnlyPtqPreset weight_only_ptq_preset = 7; } // TF SavedModel specific information for the input model. 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir index 4e883aa0e11c70..32e605ba7bf0dc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir @@ -25,7 +25,7 @@ func.func @uniform_quantized_add(%input: tensor<3x2xf32>) -> tensor<3x2xf32> { %input_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor // tensor_proto that points to dense<127> of type !tf_type.qint32. - // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> + // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() <{value = dense<127> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> %bias = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> %bias_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor %bias_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir index 317da0b762e60d..2f149281fbd0be 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir @@ -8,10 +8,10 @@ // int ops. 
func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> - %1 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %2 = "tf.XlaCallModule"(%1, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> - %3 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 18.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -17.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %3 : tensor<1x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> 
tensor<1x3xf32> + %3:4 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 5.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -3.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + return %3#0 : tensor<1x3xf32> } func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> @@ -36,10 +36,10 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: func.func @main_no_unpack(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> - %1 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %2 = "tf.XlaCallModule"(%1, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> - %3 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 18.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -17.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %3 : tensor<1x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) <{id = 
"1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %3:4 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 5.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -3.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + return %3#0 : tensor<1x3xf32> } func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> @@ -47,10 +47,10 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: } // CHECK-NO-UNPACK-LABEL: func.func @main_no_unpack // CHECK-NO-UNPACK-SAME: (%[[ARG_0:.+]]: tensor<1x1024xf32>) -> tensor<1x3xf32> -// CHECK-NO-UNPACK-DAG: %[[CONST:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, {{.*}}>> +// CHECK-NO-UNPACK-DAG: %[[CONST:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32:1, {{.*}}>> // CHECK-NO-UNPACK: %[[QUANTIZE_0:.+]] = 
stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x1024xf32>) -> tensor<1x1024x!quant.uniform> // CHECK-NO-UNPACK: %[[DOT:.+]] = stablehlo.dot_general %[[QUANTIZE_0]], %[[CONST]] -// CHECK-NO-UNPACK: %[[QUANTIZE_1:.+]] = stablehlo.uniform_quantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-NO-UNPACK: %[[QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> // CHECK-NO-UNPACK: %[[DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[QUANTIZE_1]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK-NO-UNPACK: return %[[DEQUANTIZE]] : tensor<1x3xf32> @@ -60,20 +60,15 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> - %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> } func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes 
{_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } -// CHECK-LABEL: func.func @main +// CHECK: func.func @main // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x1024xf32>) -> tensor<1x3xf32> - // CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<{{.*}}> : tensor<1024x3xf32> -// CHECK: "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) - -// CHECK: func.func private @composite_dot_general_fn_1 -// CHECK-SAME: attributes {_from_xla_call_module} -// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general -// CHECK-SAME: contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> +// CHECK: stablehlo.dot_general %[[ARG_0]], %[[CONST_0]] +// CHECK-NOT: tf.XlaCallModule diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir index 1fe56cde49601d..954323af9ef7ad 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir @@ -8,10 +8,10 @@ func.func @main(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { } // CHECK: @main(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32> // CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> tensor<1x4xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) // CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, 
_original_entry_function = "composite_dot_general_fn_1" -// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) // CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32> // CHECK: } // CHECK: } @@ -28,10 +28,10 @@ func.func @serving_default(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { } // CHECK: @serving_default(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32> // CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> tensor<1x4xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) // CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" -// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) // CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32> // CHECK: } // CHECK: } @@ -51,12 +51,12 @@ func.func @main(%arg0: tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> { // [b, 0, 1, f]). The weight constant is folded into [0, 1, i, o] format. 
// CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<3.000000e+00> : tensor<3x3x8x8xf32> // CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %arg0, dims = [0, 2, 3, 1] : (tensor<1x8x4x4xf32>) -> tensor<1x4x4x8xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[TRANSPOSE_1]]) {{.*}} : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[TRANSPOSE_1]]) {{.*}} : (tensor<1x4x4x8xf32>) -> (tensor<1x4x4x8xf32>, tensor, tensor, tensor<0xi64>) // Corresponds to the converted `stablehlo.convolution`. Note that the shapes // correspond to the dimension numbers of: [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] // CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) {{.*}} : (tensor<1x4x4x8xf32>, tensor<3x3x8x8xf32>) -> tensor<1x4x4x8xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) {{.*}} : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) {{.*}} : (tensor<1x4x4x8xf32>) -> (tensor<1x4x4x8xf32>, tensor, tensor, tensor<0xi64>) // CHECK: %[[TRANSPOSE_2:.+]] = stablehlo.transpose %[[CUSTOM_AGGREGATOR_1]], dims = [0, 3, 1, 2] : (tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> // CHECK: return %[[TRANSPOSE_2]] : tensor<1x8x4x4xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir index 240b10d8438431..ec757bc96effaa 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir @@ -114,7 +114,7 @@ func.func @func_conv_batchnorm_relu6_dynamic(%arg_0: tensor) -> (te // This test makes sure functions with tf._noinline=true is not inlined. 
module { - func.func @partitioned_call(%arg0: tensor<1x2x2x3xf32>) -> (tensor<1x2x2x3xf32>) { + func.func @stateful_partitioned_call(%arg0: tensor<1x2x2x3xf32>) -> (tensor<1x2x2x3xf32>) { %0 = "tf.StatefulPartitionedCall"(%arg0) <{ config = "", config_proto = "", executor_type = "", f = @some_func }> { @@ -139,7 +139,7 @@ module { module { func.func @partitioned_call(%arg0: tensor<1x2x2x3xf32>) -> (tensor<1x2x2x3xf32>) { - %0 = "tf.StatefulPartitionedCall"(%arg0) <{ + %0 = "tf.PartitionedCall"(%arg0) <{ config = "", config_proto = "", executor_type = "", f = @some_func }> { _collective_manager_ids = [], device = "" @@ -153,6 +153,6 @@ module { } // CHECK: module -// CHECK-NOT: tf.StatefulPartitionedCall +// CHECK-NOT: tf.PartitionedCall // CHECK-NOT: some_func // CHECK-NOT: func.call diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir index 69bf09104c814d..d6afb6461c0da9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir @@ -123,3 +123,35 @@ func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { // STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _original_entry_function // STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY-NOT: _quantization_method // STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _tfl_quant_trait = "fully_quantizable" + +// ----- + +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=static-range-ptq-to-all" \ +// RUN: -split-input-file | FileCheck %s --check-prefix=STATIC-RANGE-PTQ-TO-ALL + +// STATIC-RANGE-PTQ-TO-ALL-LABEL: @some_func +func.func @some_func(%arg0: 
tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + return %1 : tensor<1x1x64xf32> +} +// Tests that XlaCallModuleOp in non-main function has attributes set correctly. + +// STATIC-RANGE-PTQ-TO-ALL: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00> +// STATIC-RANGE-PTQ-TO-ALL: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) + +// Check that the `_quantization_method` attribute contains the quantization +// method in textproto format, enabling static-range PTQ. +// STATIC-RANGE-PTQ-TO-ALL-SAME: _entry_function = @composite_dot_general_fn_1 +// STATIC-RANGE-PTQ-TO-ALL-SAME: _original_entry_function +// STATIC-RANGE-PTQ-TO-ALL-SAME: _quantization_method = "static_range_ptq { }" +// STATIC-RANGE-PTQ-TO-ALL-SAME: _tfl_quant_trait = "fully_quantizable" + +// STATIC-RANGE-PTQ-TO-ALL: return %[[XLA_CALL_MODULE:.+]] : tensor<1x1x64xf32> +// STATIC-RANGE-PTQ-TO-ALL: } + +// STATIC-RANGE-PTQ-TO-ALL-LABEL: private @composite_dot_general_fn_1 +// STATIC-RANGE-PTQ-TO-ALL-SAME: tf_quant.composite_function +// STATIC-RANGE-PTQ-TO-ALL: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1 +// STATIC-RANGE-PTQ-TO-ALL: return %[[DOT_GENERAL:.+]] : tensor<1x1x64xf32> +// STATIC-RANGE-PTQ-TO-ALL: } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/merge-fusion-with-dequantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/merge-fusion-with-dequantize.mlir new file mode 100644 index 00000000000000..c228b25a2903c9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/merge-fusion-with-dequantize.mlir @@ -0,0 +1,198 @@ +// RUN: stablehlo-quant-opt %s -stablehlo-merge-fusion-with-dequantize -split-input-file -verify-diagnostics | FileCheck %s + +// Merge fusion with 
dequantize for relu case. + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu_fusion + func.func private @merge_relu_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_relu_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu_fn + func.func private @quantized_dot_general_relu_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + // CHECK: %[[MIN:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: %[[MAX:.*]] = chlo.broadcast_maximum %[[DQ]], %[[MIN]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Merge fusion with dequantize for relu6 case. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu6_fusion + func.func private @merge_relu6_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu6_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_relu6_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu6_fn + func.func private @quantized_dot_general_relu6_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + // CHECK-DAG: %[[MIN:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[MAX:.*]] = stablehlo.constant dense<6.000000e+00> : tensor + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[MIN]], %[[DQ]], %[[MAX]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Merge fusion with dequantize for no activation case. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_no_act_fusion + func.func private @merge_no_act_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_fn + func.func private @quantized_dot_general_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: return %[[DQ]] : tensor<1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Do not merge when quant.uniform result is used directly. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @no_merge_fusion_direct_usage + func.func private @no_merge_fusion_direct_usage(%arg0: tensor<1x4xf32>) -> (tensor<1x3xf32>, tensor<1x3x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu_fn + // CHECK-SAME: -> tensor<1x3x!quant.uniform> + %2 = call @quantized_dot_general_relu_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3, %2 : tensor<1x3xf32>, tensor<1x3x!quant.uniform> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu_fn + func.func private @quantized_dot_general_relu_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Do not merge when fusion and dequantize is already merged. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @no_merge_fusion_already_merged + func.func private @no_merge_fusion_already_merged(%arg0: tensor<1x4xf32>) -> (tensor<1x3xf32>) { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_fn + func.func private @quantized_dot_general_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_dequantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +// Do not merge when function is not quantized function. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu_fusion + func.func private @merge_relu_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @some_func + // CHECK-SAME: -> tensor<1x3x!quant.uniform> + %2 = call @some_func(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @some_func + func.func private @some_func( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Do not merge when the quantized fusion is invalid. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu_fusion + func.func private @merge_relu_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu_fn + // CHECK-SAME: -> tensor<1x3x!quant.uniform> + %2 = call @quantized_dot_general_relu_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu_fn + func.func private @quantized_dot_general_relu_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + %0 = stablehlo.constant() {value = dense<2> : tensor<1x3xi8>} : () -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir index 06edf90896e5ca..5e75126225244a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir @@ -11,7 +11,7 @@ module attributes {tf_saved_model.semantics} { %2 = "quantfork.dcast"(%1) : (tensor<4x3x!quant.uniform:f32:1, {5.000000e-03, 5.000000e-03, 5.000000e-03}>>) -> tensor<4x3xf32> %3 = "quantfork.qcast"(%arg0) {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> %4 = 
"quantfork.dcast"(%3) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> - %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> %6 = "quantfork.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %7 = "quantfork.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> return %7 : tensor<1x3xf32> @@ -22,7 +22,10 @@ module attributes {tf_saved_model.semantics} { // CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> // CHECK-DAG: %[[QCAST_0:.+]] = "quantfork.qcast"(%[[CONST_0]]) {volatile} : (tensor<4x3xf32>) -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> // CHECK-DAG: %[[QCAST_1:.+]] = "quantfork.qcast"(%[[ARG_0]]) {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[QCAST_1]], %[[QCAST_0]]) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[QCAST_1]], %[[QCAST_0]]) +// Test that the `Method` has been copied over. 
+// CHECK-SAME: {_quantization_method = "static_range_ptq { }"} +// CHECK-SAME: : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> // CHECK: %[[DCAST_0:.+]] = "quantfork.dcast"(%[[CALL_0]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK: return @@ -40,7 +43,7 @@ module attributes {tf_saved_model.semantics} { // CHECK-LABEL: quantize_simple_xla_call_module_no_operand func.func private @quantize_simple_xla_call_module_no_operand() -> tensor<1x3xf32> { - %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> @@ -63,7 +66,7 @@ module attributes {tf_saved_model.semantics} { %4 = "quantfork.dcast"(%3) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // expected-error @+2 {{Failed to find a valid entry function}} // expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} - %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, 
_original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %6 = "quantfork.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %7 = "quantfork.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> return %7 : tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir index 04104c308a3b3d..d94e1ca3787a3c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir @@ -32,7 +32,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %5 = "quantfork.dcast"(%4) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> %6 = "quantfork.qcast"(%arg0) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> %7 = "quantfork.dcast"(%6) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> - %8 = "tf.XlaCallModule"(%7, %5) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, 
module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %8 = "tf.XlaCallModule"(%7, %5) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> %9 = "quantfork.qcast"(%8) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> %10 = "quantfork.dcast"(%9) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> %11 = "stablehlo.reduce_window"(%10, %3) ({ @@ -98,7 +98,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %8 = "quantfork.dcast"(%7) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> %9 = "quantfork.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> %10 = "quantfork.dcast"(%9) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> - %11 = "tf.XlaCallModule"(%8, %10) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, 
tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %11 = "tf.XlaCallModule"(%8, %10) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> %12 = "quantfork.qcast"(%11) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> %13 = "quantfork.dcast"(%12) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> return %13 : tensor<2x3x1x3xf32> @@ -150,7 +150,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %5 = "quantfork.dcast"(%4) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> %6 = "quantfork.qcast"(%arg0) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> %7 = "quantfork.dcast"(%6) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> - %8 = "tf.XlaCallModule"(%7, %5) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %8 = "tf.XlaCallModule"(%7, %5) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = 
"composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> %9 = "quantfork.qcast"(%8) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> %10 = "quantfork.dcast"(%9) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> %11 = stablehlo.reshape %10 : (tensor<2x3x1x3xf32>) -> tensor<2x3x3xf32> @@ -223,7 +223,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %11 = "quantfork.dcast"(%10) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> %12 = "quantfork.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> %13 = "quantfork.dcast"(%12) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> - %14 = "tf.XlaCallModule"(%11, %13) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %14 = "tf.XlaCallModule"(%11, %13) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> %15 = 
"quantfork.qcast"(%14) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> %16 = "quantfork.dcast"(%15) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> return %16 : tensor<2x3x1x3xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir index 9be0add0ba4551..25aab3044a3496 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir @@ -16,7 +16,7 @@ module attributes {tf_saved_model.semantics} { %1 = "quantfork.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> %3 = "quantfork.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> - %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %5 = 
"quantfork.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %6 = "quantfork.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> %7 = stablehlo.reshape %6 : (tensor<1x3xf32>) -> tensor<3x1xf32> @@ -58,7 +58,7 @@ module attributes {tf_saved_model.semantics} { %1 = "quantfork.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> %3 = "quantfork.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> - %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %5 = "quantfork.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %6 = "quantfork.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> %7 = stablehlo.reshape %6 : (tensor<1x3xf32>) -> tensor<3x1xf32> @@ -133,7 +133,7 @@ module attributes {tf_saved_model.semantics} { %6 = "quantfork.dcast"(%5) : (tensor<4x2x!quant.uniform>) -> tensor<4x2xf32> %7 = "quantfork.qcast"(%arg2) {volatile} : (tensor<2x5xf32>) -> 
tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> %8 = "quantfork.dcast"(%7) : (tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x5xf32> - %9 = "tf.XlaCallModule"(%6, %8) {Sout = [#tf_type.shape<4x5>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<4x2xf32>, tensor<2x5xf32>) -> tensor<4x5xf32> + %9 = "tf.XlaCallModule"(%6, %8) {Sout = [#tf_type.shape<4x5>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<4x2xf32>, tensor<2x5xf32>) -> tensor<4x5xf32> %10 = "quantfork.qcast"(%9) {volatile} : (tensor<4x5xf32>) -> tensor<4x5x!quant.uniform> %11 = "quantfork.dcast"(%10) : (tensor<4x5x!quant.uniform>) -> tensor<4x5xf32> return %11 : tensor<4x5xf32> @@ -173,7 +173,7 @@ module attributes {tf_saved_model.semantics} { %1 = "quantfork.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> %3 = "quantfork.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> - %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = 
"fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %5 = "quantfork.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %6 = "quantfork.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> %7 = "quantfork.qcast"(%arg2) {volatile} : (tensor) -> tensor> @@ -218,7 +218,7 @@ module attributes {tf_saved_model.semantics} { %1 = "quantfork.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> %3 = "quantfork.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> - %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = 
{}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %5 = "quantfork.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %6 = "quantfork.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> %7 = "quantfork.qcast"(%arg3) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> @@ -260,7 +260,7 @@ module attributes {tf_saved_model.semantics} { %1 = "quantfork.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> %3 = "quantfork.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> - %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %5 = "quantfork.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %6 = "quantfork.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> %7 = 
stablehlo.broadcast_in_dim %6, dims = [2, 1] : (tensor<1x3xf32>) -> tensor<2x3x2xf32> @@ -302,7 +302,7 @@ module attributes {tf_saved_model.semantics} { %1 = "quantfork.dcast"(%0) : (tensor<3x4x5x!quant.uniform>) -> tensor<3x4x5xf32> %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<3x5x2xf32>) -> tensor<3x5x2x!quant.uniform:f32, 6.000000e-03:13>> %3 = "quantfork.dcast"(%2) : (tensor<3x5x2x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<3x5x2xf32> - %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x5xf32>, tensor<3x5x2xf32>) -> tensor<3x4x2xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x5xf32>, tensor<3x5x2xf32>) -> tensor<3x4x2xf32> %5 = "quantfork.qcast"(%4) {volatile} : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> %6 = "quantfork.dcast"(%5) : (tensor<3x4x2x!quant.uniform>) -> tensor<3x4x2xf32> %7 = "stablehlo.gather"(%6, %arg2) { @@ -350,7 +350,7 @@ module attributes {tf_saved_model.semantics} { %1 = "quantfork.dcast"(%0) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> %3 = "quantfork.dcast"(%2) : (tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) 
-> tensor<2x4xf32> - %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x2xf32>, tensor<2x4xf32>) -> tensor<3x4xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x2xf32>, tensor<2x4xf32>) -> tensor<3x4xf32> %5 = "quantfork.qcast"(%4) {volatile} : (tensor<3x4xf32>) -> tensor<3x4x!quant.uniform> %6 = "quantfork.dcast"(%5) : (tensor<3x4x!quant.uniform>) -> tensor<3x4xf32> %7 = stablehlo.slice %6 [1:3, 2:4] : (tensor<3x4xf32>) -> tensor<2x2xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir index 6db474de676ccc..81e8b4bde5e13e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir @@ -8,7 +8,7 @@ module attributes {tf_saved_model.semantics} { %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> %0 = "quantfork.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> %1 = "quantfork.dcast"(%0) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32> - %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], 
disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> } @@ -22,7 +22,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> // CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> -// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[Q]]) : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[Q]]) +// CHECK-SAME: {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK: return %[[CALL]] // CHECK: quantized_dot_general_fn @@ -41,7 +42,7 @@ module attributes {tf_saved_model.semantics} { %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> %0 = "quantfork.qcast"(%cst) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> - %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = 
false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> return %2 : tensor<1x3x4x2xf32> } @@ -55,7 +56,7 @@ module attributes {tf_saved_model.semantics} { // CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> // CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> -// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[Q]]) : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[Q]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> // CHECK: return %[[CALL]] // CHECK: quantized_conv_fn diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir index f9fa9ce5f60b87..09f002559b7830 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir @@ -9,7 +9,7 @@ module attributes 
{tf_saved_model.semantics} { func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> } @@ -19,14 +19,14 @@ module attributes {tf_saved_model.semantics} { // CHECK: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} // CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}> // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call 
@quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>) -> tensor<1x3x!quant.uniform> // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> // CHECK-PER-TENSOR: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} // CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> // CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> // CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> // CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> @@ -57,14 +57,14 @@ module attributes {tf_saved_model.semantics} { func.func private @quantize_dot_general_batch_per_tensor_quantized_fn(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x3xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value 
= dense<3.00000000e-1> : tensor<2x2x3xf32>} : () -> tensor<2x2x3xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<2x2x2xf32>, tensor<2x2x3xf32>) -> tensor<2x2x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<2x2x2xf32>, tensor<2x2x3xf32>) -> tensor<2x2x3xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<2x2x3xf32>) -> tensor<2x2x3xf32> return %2 : tensor<2x2x3xf32> } // CHECK: func.func private @quantize_dot_general_batch_per_tensor_quantized_fn(%[[ARG_0:.+]]: tensor<2x2x2xf32>) -> tensor<2x2x3xf32> attributes {tf._original_func_name = "main_0"} // CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x2x3xi8>} : () -> tensor<2x2x3x!quant.uniform> // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<2x2x2xf32>) -> tensor<2x2x2x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<2x2x2x!quant.uniform>, tensor<2x2x3x!quant.uniform) -> tensor<2x2x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call 
@quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<2x2x2x!quant.uniform>, tensor<2x2x3x!quant.uniform) -> tensor<2x2x3x!quant.uniform> // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<2x2x3x!quant.uniform) -> tensor<2x2x3xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<2x2x3xf32> @@ -83,7 +83,7 @@ module attributes {tf_saved_model.semantics} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3xf32>} : () -> tensor<1x3xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_same_shape_fn, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_same_shape_fn, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> } 
@@ -91,7 +91,7 @@ module attributes {tf_saved_model.semantics} { // CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}> // CHECK: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> @@ -99,7 +99,7 @@ module attributes {tf_saved_model.semantics} { // CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> // CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform // CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// 
CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform // CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> // CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> @@ -132,7 +132,7 @@ module attributes {tf_saved_model.semantics} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<3xf32>} : () -> tensor<3xf32> %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor return 
%2 : tensor } @@ -140,7 +140,7 @@ module attributes {tf_saved_model.semantics} { // CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}> // CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3xi32>} : () -> tensor<3x!quant.uniform // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<3x!quant.uniform) -> tensor +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<3x!quant.uniform) -> tensor // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor @@ -148,7 +148,7 @@ module attributes {tf_saved_model.semantics} { // CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> // CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3xi32>} : () -> tensor<3x!quant.uniform // CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<3x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor>, 
tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<3x!quant.uniform) -> tensor // CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor // CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor @@ -221,14 +221,16 @@ module attributes {tf_saved_model.semantics} { // CHECK: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} // CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> // CHECK-PER-TENSOR: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} // CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> // CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK-PER-TENSOR: 
%[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> // CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> // CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> @@ -272,7 +274,7 @@ func.func @quantize_conv_fn_per_tensor(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3 version = 5 : i64, _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", - _quantization_method = "static_range_ptq {}", + _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "" @@ -286,7 +288,7 @@ func.func @quantize_conv_fn_per_tensor(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3 // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> // CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x3x4x3x!quant.uniform>, 
tensor<2x3x3x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform> // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> @@ -339,7 +341,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> // CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<47978> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> @@ -347,7 +350,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> // CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : 
tensor<2xi32>} : () -> tensor<2x!quant.uniform // CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform // CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> // CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> @@ -406,7 +410,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> // CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = 
"static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> @@ -414,7 +419,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> // CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform // CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform // CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> // CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> @@ -472,7 +478,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = 
dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> // CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor @@ -480,7 +487,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> // CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform // CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types
{key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor // CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor // CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor @@ -564,7 +572,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> // CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor @@ -572,7 +581,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> // CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform // 
CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> // CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor // CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor @@ -659,7 +669,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> // CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform // CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, 
{0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> // CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor @@ -667,7 +678,8 @@ module attributes {tf_saved_model.semantics} { // CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> // CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform // CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> // CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor // CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor @@ -715,12 +727,12 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. +// Tests that XlaCallModule op is not quantized and converted to func.call without the quantfork.stats ops.
module attributes {tf_saved_model.semantics} { func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> return %1 : tensor<1x3xf32> } // Check that "tf.Const" is converted to stablehlo.constant. 
XlaCallModule is @@ -728,8 +740,8 @@ module attributes {tf_saved_model.semantics} { // CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} // CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> -// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// CHECK: return %[[XLA_CALL_MODULE_0]] +// CHECK: %[[CALL:.+]] = call @composite_dot_general_fn(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> @@ -750,7 +762,7 @@ module attributes {tf_saved_model.semantics} { func.func private @quantize_gather_fn(%arg: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<1> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32> %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, _original_entry_function = "composite_gather_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, 
_original_entry_function = "composite_gather_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2xf32> return %2 : tensor<2x3x2x2xf32> } @@ -758,7 +770,7 @@ module attributes {tf_saved_model.semantics} { // calls the quantized entry function. // CHECK: %[[CONST:.+]] = stablehlo.constant dense<{{.*}}> : tensor<2x3x2xi32> // CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> -// CHECK: %[[CALL:.+]] = call @quantized_gather_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_gather_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) {_quantization_method = "static_range_ptq { }"} : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> // CHECK: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<2x3x2x2x!quant.uniform) -> tensor<2x3x2x2xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE]] : tensor<2x3x2x2xf32> @@ -776,5 +788,52 @@ module attributes {tf_saved_model.semantics} { return %0 : tensor<2x3x2x2xf32> } // CHECK: %[[GATHER:.+]] = "stablehlo.gather"(%[[ARG_0]], %[[ARG_1]]) {{.*}} : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> -// CHECK: return %[[GATHER]] : tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[GATHER]] : tensor<2x3x2x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor<2x3x2x2x!quant.uniform> +} + +// ----- + +// Tests that a 
basic `stablehlo.add` and a fused `stablehlo.dot_general` +// are properly quantized. + +module attributes {tf_saved_model.semantics} { +// CHECK: func.func private @quantize_add_fn(%[[ARG:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} + func.func private @quantize_add_fn(%arg: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst_0 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> + %cst_1 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _original_entry_function = "composite_add_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %3 = "quantfork.stats"(%2) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %4 = "tf.XlaCallModule"(%3, %cst_1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantfork.stats"(%4) {layerStats = 
dense<[5.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %5 : tensor<1x3xf32> + } +// CHECK: %[[CONST:.+]] = stablehlo.constant() {value = dense<127> : tensor<1x2xi8>} : () -> tensor<1x2x!quant.uniform> +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_add_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[UNIFORM_DEQUANTIZE]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK: func.func private @quantized_add_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_add_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32> + return %0 : tensor<1x2xf32> + } +// CHECK: %[[ADD:.+]] = stablehlo.add %arg0, %arg1 : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> 
tensor<1x2x!quant.uniform> +// CHECK: return %[[ADD]] : tensor<1x2x!quant.uniform> + +// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// CHECK: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1,{{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE]] : tensor<1x3x!quant.uniform> } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir deleted file mode 100644 index 72851d92b64b75..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir +++ /dev/null @@ -1,46 +0,0 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ -// RUN: -stablehlo-quantize-composite-functions=enable-full-int-quantization=true | FileCheck --check-prefix=CHECK-FULL-INT %s - -// Tests that a basic `stablehlo.add` and a fused `stablehlo.dot_general` -// are properly quantized. 
- -module attributes {tf_saved_model.semantics} { -// CHECK-FULL-INT: func.func private @quantize_add_fn(%[[ARG:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} - func.func private @quantize_add_fn(%arg: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst_0 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> - %cst_1 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %1 = "tf.XlaCallModule"(%0, %cst_0) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _original_entry_function = "composite_add_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %3 = "quantfork.stats"(%2) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %4 = "tf.XlaCallModule"(%3, %cst_1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - %5 = "quantfork.stats"(%4) {layerStats = dense<[5.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %5 : tensor<1x3xf32> - } -// CHECK-FULL-INT: %[[CONST:.+]] = 
stablehlo.constant() {value = dense<127> : tensor<1x2xi8>} : () -> tensor<1x2x!quant.uniform> -// CHECK-FULL-INT: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>> -// CHECK-FULL-INT: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK-FULL-INT: %[[CALL:.+]] = call @quantized_add_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> -// CHECK-FULL-INT: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// CHECK-FULL-INT: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[UNIFORM_DEQUANTIZE]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK-FULL-INT: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> -// CHECK-FULL-INT: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> -// CHECK-FULL-INT: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> - -// CHECK-FULL-INT: func.func private @quantized_add_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_add_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32> - return %0 : tensor<1x2xf32> - } -// CHECK-FULL-INT: %[[ADD:.+]] = stablehlo.add %arg0, %arg1 : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> -// CHECK-FULL-INT: return %[[ADD]] : tensor<1x2x!quant.uniform> - -// CHECK-FULL-INT: func.func private 
@quantized_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } -// CHECK-FULL-INT: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1,{{.*}}>>) -> tensor<1x3x!quant.uniform> -// CHECK-FULL-INT: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> -// CHECK-FULL-INT: return %[[UNIFORM_QUANTIZE]] : tensor<1x3x!quant.uniform> -} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir index dce15fe07760e2..b96cb15039d763 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir @@ -2,12 +2,12 @@ // RUN: -stablehlo-quantize-composite-functions=enable-weight-only=true | FileCheck --check-prefix=CHECK %s // Test that weight-only quantized dot_general op is produced when -// enable-weight-only is set to true. +// weight_only_ptq is provided. 
module attributes {tf_saved_model.semantics} { func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> return %1 : tensor<1x3xf32> } @@ -20,7 +20,7 @@ module attributes {tf_saved_model.semantics} { // CHECK-LABEL: quantize_dot_general_fn // CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> // CHECK: %[[CST:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> -// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK: return %[[CALL]] // CHECK: quantized_dot_general_fn @@ -31,13 +31,13 @@ module attributes {tf_saved_model.semantics} { // ----- -// Test that hybrid quantized convolution op is produced when 
enable-weight-only -// is set to true. +// Test that hybrid quantized convolution op is produced when weight_only_ptq is +// provided. module attributes {tf_saved_model.semantics} { func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> - %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> return %1 : tensor<1x3x4x2xf32> } @@ -50,7 +50,7 @@ module attributes {tf_saved_model.semantics} { // CHECK-LABEL: quantize_conv_fn // CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> // CHECK: %[[CST:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform> -// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> // CHECK: return %[[CALL]] // CHECK: 
quantized_conv_fn diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir index eccf25e1acbfed..02e1c5e9923915 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir @@ -22,25 +22,25 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p func.func @main(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x64xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> %1 = stablehlo.constant dense<1.000000e+03> : tensor<1x3xf32> - %2 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %3 = "tf.XlaCallModule"(%2, %0, %1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> - %4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 
0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> + %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %3 = "tf.XlaCallModule"(%2#0, %0, %1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) %5 = stablehlo.constant dense<1.000000e+03> : tensor<3x64xf32> %6 = stablehlo.constant dense<1.000000e+03> : tensor<1x64xf32> - %7 = "tf.CustomAggregator"(%4) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> - %8 = "tf.XlaCallModule"(%7, %5, %6) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x3xf32>, tensor<3x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> - %9 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : 
(tensor<1x64xf32>) -> tensor<1x64xf32> - return %9 : tensor<1x64xf32> + %7:4 = "tf.CustomAggregator"(%4#0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + %8 = "tf.XlaCallModule"(%7#0, %5, %6) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x3xf32>, tensor<3x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %9:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x64xf32>) -> (tensor<1x64xf32>, tensor, tensor, tensor<*xi64>) + return %9#0 : tensor<1x64xf32> } // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], 
%[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable"} - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_0]]) + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_0]]) // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0 - // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable"} - // CHECK: %[[CUSTOM_AGGREGATOR_3:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_1:.*]]) + // CHECK: %[[CUSTOM_AGGREGATOR_3:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_1:.*]]) // CHECK: return %[[CUSTOM_AGGREGATOR_3]] : tensor<1x64xf32> // CHECK: 
} @@ -111,16 +111,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: @serving_default func.func @serving_default(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> - %1 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> - %3 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %3 : tensor<1x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, 
module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + return %3#0 : tensor<1x3xf32> } // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}} - // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: return %[[CUSTOM_AGGREGATOR_1]] // CHECK: } @@ -143,16 +143,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-LABEL: @random_name func.func @random_name(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path 
= ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> - %1 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> - %3 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %3 : tensor<1x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 
0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + return %3#0 : tensor<1x3xf32> } // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> - // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: return %[[CUSTOM_AGGREGATOR_1]] // CHECK: } @@ -185,19 +185,19 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: @serving_default func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output1"]}, tensor<1024x3xf32> {tf_saved_model.index_path = ["output2"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> - %1 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", 
initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> - %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> - %3 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) %4 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> - %5 = stablehlo.add %3, %4 : tensor<1024x3xf32> - %6 = stablehlo.multiply %3, %0 : tensor<1024x3xf32> + %5 = stablehlo.add %3#0, %4 : tensor<1024x3xf32> + %6 = stablehlo.multiply %3#0, %0 : 
tensor<1024x3xf32> return %5, %6 : tensor<1024x3xf32>, tensor<1024x3xf32> } // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" - // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_2:.*]]:2 = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 // CHECK: return %[[SUBGRAPH_2]]#0, %[[SUBGRAPH_2]]#1 // CHECK: } @@ -235,18 +235,18 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %0 = stablehlo.constant dense<1.000000e+03> : tensor<3x11xf32> // %1 is large enough that it won't be duplicated. 
%1 = stablehlo.constant dense<1.000000e+01> : tensor<3x11xf32> - %2 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> tensor<3x3xf32> - %3 = "tf.XlaCallModule"(%2, %0) {Sout = [#tf_type.shape<3x11>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x3xf32>, tensor<3x11xf32>) -> tensor<3x11xf32> - %4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x11xf32>) -> tensor<3x11xf32> - %5 = stablehlo.add %4, %1 : tensor<3x11xf32> + %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor, tensor, tensor<*xi64>) + %3 = "tf.XlaCallModule"(%2#0, %0) {Sout = [#tf_type.shape<3x11>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x3xf32>, tensor<3x11xf32>) -> tensor<3x11xf32> + %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x11xf32>) -> (tensor<3x11xf32>, tensor, tensor, tensor<*xi64>) + %5 = stablehlo.add %4#0, %1 : tensor<3x11xf32> %6 = stablehlo.multiply %5, %1 : tensor<3x11xf32> return %6 : tensor<3x11xf32> 
} // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<3x11>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<3x11>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" - // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]]) <{Sout = [#tf_type.shape<3x11>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_0 // CHECK: return %[[SUBGRAPH_2]] // CHECK: } @@ -293,16 +293,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %4 = stablehlo.compare EQ, %3, %2, NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> stablehlo.custom_call @shape_assertion(%4) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () %5 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> - %6 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> - %7 = "tf.XlaCallModule"(%6, %5) 
{Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> - %8 = "tf.CustomAggregator"(%7) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> - %9 = stablehlo.add %8, %0 : tensor<1024x3xf32> + %6:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %7 = "tf.XlaCallModule"(%6#0, %5) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %8:4 = "tf.CustomAggregator"(%7) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %9 = stablehlo.add %8#0, %0 : tensor<1024x3xf32> return %9 : tensor<1024x3xf32> } // CHECK: %[[SUBGRAPH_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, 
max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_0]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 // CHECK: return %[[SUBGRAPH_1]] : tensor<1024x3xf32> // CHECK: } @@ -339,18 +339,18 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %2 = stablehlo.remainder %0, %1 : tensor<1024x3xf32> %3 = "tf.Identity"(%2) {device = ""} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> %4 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> - %5 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> - %6 = "tf.XlaCallModule"(%5, %4) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> - %7 = 
"tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> - %8 = stablehlo.add %7, %0 : tensor<1024x3xf32> + %5:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %6 = "tf.XlaCallModule"(%5#0, %4) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %7:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %8 = stablehlo.add %7#0, %0 : tensor<1024x3xf32> return %8 : tensor<1024x3xf32> } // CHECK: %[[SUBGRAPH_0:.*]]:2 = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_2 // CHECK: %[[IDENTIFY:.*]] = "tf.Identity"(%[[SUBGRAPH_0]]#1) {device = ""} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> 
tensor<1024x1024xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_0]]#0) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 // CHECK: return %[[SUBGRAPH_2]] : tensor<1024x3xf32> // CHECK: } @@ -394,16 +394,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %7 = stablehlo.compare EQ, %6, %2, NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> stablehlo.custom_call @shape_assertion(%7) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () %8 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> - %9 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> - %10 = "tf.XlaCallModule"(%9, %8) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> 
tensor<1024x3xf32> - %11 = "tf.CustomAggregator"(%10) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> - %12 = stablehlo.add %11, %0 : tensor<1024x3xf32> + %9:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %10 = "tf.XlaCallModule"(%9#0, %8) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %11:4 = "tf.CustomAggregator"(%10) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %12 = stablehlo.add %11#0, %0 : tensor<1024x3xf32> return %12 : tensor<1024x3xf32> } // CHECK: %[[SUBGRAPH_0:.*]]:2 = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 - // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 
0.000000e+00 : f32} // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_0]]#1) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" - // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_0]]#0) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 // CHECK: return %[[SUBGRAPH_1]] : tensor<1024x3xf32> // CHECK: } @@ -411,11 +411,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // ----- -// main function contains StatefulPartitionedCall ops which is used to preserve -// aliased functions. This test make sure stablehlo ops in each PartitionedCall -// functions are lifted. +// main function contains PartitionedCall and StatefulPartitionedCall ops which +// is used to preserve aliased functions. This test make sure stablehlo ops in +// each PartitionedCall functions are lifted. 
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_2 + // CHECK: stablehlo.multiply %arg1, %arg2 : tensor<3x3xf32> + // CHECK: return + // CHECK: } + // CHECK: func private @_stablehlo_main_1 // CHECK: stablehlo.add %arg1, %arg2 : tensor<3x3xf32> // CHECK: return @@ -435,13 +440,20 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p }> { _collective_manager_ids = [], device = "" } : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - return %2 : tensor<3x3xf32> + %3 = "tf.PartitionedCall"(%2, %1) <{ + config = "", config_proto = "", executor_type = "", f = @some_other_func + }> { + _collective_manager_ids = [], device = "" + } : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %3 : tensor<3x3xf32> } // CHECK: func.func @main - // CHECK: %[[INPUT:.*]]:2 = "tf.XlaCallModule"() + // CHECK: %[[INPUT:.*]]:3 = "tf.XlaCallModule"() // CHECK-SAME: _entry_function = @_stablehlo_main_0 - // CHECK: "tf.StatefulPartitionedCall"(%[[INPUT]]#0, %[[INPUT]]#1) + // CHECK: %[[ADD:.*]] = "tf.StatefulPartitionedCall"(%[[INPUT]]#1, %[[INPUT]]#2) // CHECK-SAME: f = @some_func + // CHECK: "tf.PartitionedCall"(%[[ADD]], %[[INPUT]]#0) + // CHECK-SAME: f = @some_other_func // CHECK: return func.func private @some_func(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> attributes {tf._noinline = true} { @@ -452,4 +464,13 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: tf.XlaCallModule // CHECK-SAME: _entry_function = @_stablehlo_main_1 // CHECK: return + + func.func private @some_other_func(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> attributes {tf._noinline = true} { + %0 = stablehlo.multiply %arg0, %arg1 : tensor<3x3xf32> + return %0 : tensor<3x3xf32> + } + // CHECK: func.func private @some_other_func + // CHECK: tf.XlaCallModule + // 
CHECK-SAME: _entry_function = @_stablehlo_main_2 + // CHECK: return } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unfuse_mhlo_batch_norm.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unfuse_mhlo_batch_norm.mlir index 7a57bdd64b2aab..0873b68b475b5a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unfuse_mhlo_batch_norm.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unfuse_mhlo_batch_norm.mlir @@ -13,10 +13,10 @@ func.func @unfuse_batch_norm( // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor<256xf32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: 
%[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> // CHECK: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> // CHECK: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/xla_call_module_to_call.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/xla_call_module_to_call.mlir new file mode 100644 index 00000000000000..f0330d0266d56d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/xla_call_module_to_call.mlir @@ -0,0 +1,23 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-xla-call-module-to-call | FileCheck %s + +// ----- + +// Tests composite tf.XlaCallModule is converted to func.call. + +module { + // CHECK-LABEL: func.func @main + func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { + // CHECK: call @composite_dot_general_fn_1 + // CHECK-SAME: (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + // CHECK-NOT: tf.XlaCallModule + %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_fn_1 + // CHECK-SAME: -> tensor<1x3xf32> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: 
tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index be0792ab76aff3..94dc1b1569620f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -411,6 +411,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:calibration_parameters", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:const_op_size", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:quantization_unit_loc", @@ -440,8 +441,6 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:path", "//tensorflow/core/tpu:tpu_defs", - "//tensorflow/lite/kernels:padding", - "//tensorflow/lite/kernels/internal:quantization_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD index fe081684e55736..9ae8d6401afcd6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD @@ -79,7 +79,6 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":calibration_statistics_proto_cc", - "//tensorflow/core:framework", 
"@com_google_absl//absl/types:span", ], ) @@ -94,6 +93,7 @@ cc_library( ":calibration_statistics_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "@com_google_absl//absl/types:span", ], ) @@ -107,6 +107,7 @@ cc_library( ":calibration_statistics_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "@com_google_absl//absl/types:span", ], ) @@ -119,7 +120,9 @@ cc_library( ":calibration_statistics_collector_base", ":calibration_statistics_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:calibration_parameters", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "@com_google_absl//absl/types:span", ], ) @@ -169,7 +172,7 @@ tf_cc_test( srcs = ["calibration_statistics_collector_test.cc"], deps = [ ":calibration_statistics_collector_average_min_max", - ":calibration_statistics_collector_base", + ":calibration_statistics_collector_histogram", ":calibration_statistics_collector_min_max", ":calibration_statistics_proto_cc", "//tensorflow/core:test", @@ -203,8 +206,12 @@ tf_kernel_library( deps = [ ":calibrator_singleton_impl", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:calibration_parameters", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:errors", ], ) @@ -255,3 +262,48 @@ tf_python_pybind_extension( "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], ) + +tf_kernel_library( + name = 
"calibration_statistics_saver_op", + srcs = ["calibration_statistics_saver_op.cc"], + compatible_with = get_compatible_with_portable(), + visibility = [ + "//tensorflow:__pkg__", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:__pkg__", + ], + deps = [ + ":calibration_statistics_collector_average_min_max", + ":calibration_statistics_collector_base", + ":calibration_statistics_collector_histogram", + ":calibration_statistics_collector_min_max", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:logging", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:env", + ], +) + +tf_cc_test( + name = "calibration_statistics_saver_op_test", + srcs = ["calibration_statistics_saver_op_test.cc"], + deps = [ + ":calibration_statistics_proto_cc", + ":calibration_statistics_saver_op", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:status_matchers", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.proto b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.proto index eca79c4a141b3c..d4bc053a77cf32 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.proto +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.proto @@ -37,10 +37,15 @@ message CalibrationStatistics { // 
hist_freq[i] saves frequency of range [bins[i], bins[i + 1]). // bins[i] = lower_bound + bin_width * i // bins[i + 1] = lower_bound + bin_width * (i + 1) - repeated int64 hist_freq = 3; + repeated float hist_freq = 3; } MinMaxStatistics min_max_statistics = 1; AverageMinMaxStatistics average_min_max_statistics = 2; HistogramStatistics histogram_statistics = 3; } + +message CalibrationStatisticsMap { + // A map from the id of CustomAggregator op to its collected statistics. + map statistics = 1; +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.cc index cab8abdee4b5a7..e1faa17505edf5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.cc @@ -14,10 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h" -#include -#include +#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" namespace tensorflow { @@ -30,21 +30,13 @@ void CalibrationStatisticsCollectorAverageMinMax::ClearData() { } void CalibrationStatisticsCollectorAverageMinMax::Collect( - const float *data, const unsigned int N) { - float input_min = std::numeric_limits::max(), - input_max = std::numeric_limits::lowest(); + const float min, const float max, absl::Span histogram) { + const float current_min_sum = average_min_max_statistics_.min_sum(); + const float current_max_sum = average_min_max_statistics_.max_sum(); + const int current_num_samples = average_min_max_statistics_.num_samples(); - for (int i = 0; i < N; ++i) { - input_min = std::min(input_min, data[i]); - input_max = std::max(input_max, data[i]); - } - - float current_min_sum = average_min_max_statistics_.min_sum(); - float current_max_sum = average_min_max_statistics_.max_sum(); - int current_num_samples = average_min_max_statistics_.num_samples(); - - average_min_max_statistics_.set_min_sum(current_min_sum + input_min); - average_min_max_statistics_.set_max_sum(current_max_sum + input_max); + average_min_max_statistics_.set_min_sum(current_min_sum + min); + average_min_max_statistics_.set_max_sum(current_max_sum + max); average_min_max_statistics_.set_num_samples(current_num_samples + 1); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h index 317b96ea423cb7..f6a5da84f1d675 100644 --- 
a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" @@ -37,7 +38,8 @@ class CalibrationStatisticsCollectorAverageMinMax void ClearData() override; - void Collect(const float *data, unsigned int N) override; + void Collect(float min, float max, + absl::Span histogram) override; std::optional GetStatistics() const override; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h index 26417a1a6dae4d..9ce6a81930a7d6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h @@ -15,12 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_BASE_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_BASE_H_ +#include #include -#include #include "absl/types/span.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" -#include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace calibrator { @@ -30,25 +29,11 @@ namespace calibrator { // statistics based on the calibration methods. 
class CalibrationStatisticsCollectorBase { public: - // Collect data for calibration using float vector. - // It internally calls private method Collect(float*, unsigned int) - void Collect(const std::vector& data_vec) { - Collect(data_vec.data(), data_vec.size()); - } - // Collect data for calibration using absl::Span. - // It internally calls private method Collect(float*, unsigned int) - void Collect(absl::Span data_span) { - Collect(data_span.data(), data_span.size()); - } - // Collect data for calibration using Tensor - // It internally calls private method Collect(float*, unsigned int) - void Collect(const Tensor& data_tensor) { - auto data_flat = data_tensor.flat(); - Collect(data_flat.data(), data_flat.size()); - } + // Collect data for calibration. + virtual void Collect(float min, float max, + absl::Span histogram) = 0; virtual void ClearData() = 0; - virtual void Collect(const float* data, unsigned int N) = 0; // Return the statistics needed for a given calibration method. virtual std::optional GetStatistics() const = 0; virtual ~CalibrationStatisticsCollectorBase() = default; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.cc index 12e9f19dff2cab..1ad5bf1aba2091 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.cc @@ -16,60 +16,94 @@ limitations under the License. 
#include #include +#include #include #include +#include +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" namespace tensorflow { namespace calibrator { - -void CalibrationStatisticsCollectorHistogram::ClearData() { - num_bins_ = 256; - bin_width_ = 0; - hist_freq_.resize(num_bins_, 0); +namespace { + +using ::stablehlo::quantization::CalculateBinIndex; +using ::stablehlo::quantization::CalculateBinWidth; +using ::stablehlo::quantization::CalculateLowerBound; + +// Gets the histogram frequencies for the given range. +float GetRangeFrequencies(absl::Span histogram, + const float bin_width, const float lower_bound, + const float range_start, const float range_end) { + float freq_sum = 0.f; + for (float range = std::max(range_start, lower_bound); range < range_end; + range += bin_width) { + const int32_t idx = CalculateBinIndex(range, lower_bound, bin_width); + if (idx >= histogram.size()) break; + + // If the range is smaller than bin width, add the proportional value of + // that bin. + const float proportion = std::min(range_end - range, bin_width) / bin_width; + freq_sum += histogram[idx] * proportion; + } + return freq_sum; } -void CalibrationStatisticsCollectorHistogram::Collect(const float *data, - const unsigned int N) { - if (N == 0) return; +} // namespace - // When histogram is not initialized. - if (bin_width_ == 0) { - hist_freq_.resize(num_bins_, 0); - auto minmax = std::minmax_element(data, data + N); - - // The min and max of the first data will be the range of the histogram. 
- float min_value = std::floor(*minmax.first); - float max_value = std::ceil(*minmax.second); +void CalibrationStatisticsCollectorHistogram::ClearData() { + hist_freq_.clear(); +} - // The bin width is (max - min) divided by num_bins. - bin_width_ = (max_value - min_value) / num_bins_; +void CalibrationStatisticsCollectorHistogram::Collect( + const float min, const float max, absl::Span histogram) { + if (histogram.empty()) return; - // The lower bound is min value of data. - lower_bound_ = min_value; + // Reconstruct the bin width, lower and upper bound from the collected data. + const float collected_bin_width = + CalculateBinWidth(min, max, histogram.size()); + const float collected_lower_bound = + CalculateLowerBound(min, collected_bin_width); + const float collected_upper_bound = + std::ceil(max / collected_bin_width) * collected_bin_width; - // This is the worst case of first initialization, so it returns - // instantly. 1e-9 is threshold. - if (std::abs(bin_width_) < 1e-9) return; + // When histogram is not initialized. + if (hist_freq_.empty()) { + bin_width_ = collected_bin_width; + lower_bound_ = collected_lower_bound; } - for (int i = 0; i < N; ++i) { - int idx = GetHistogramIndex(data[i]); - hist_freq_[idx]++; + const auto [lower_idx, upper_idx] = + ExpandHistogramIfNeeded(collected_lower_bound, collected_upper_bound); + for (int32_t idx = lower_idx; idx <= upper_idx; ++idx) { + // Calculate the range covered by this index then add with the collected + // frequency associated to that range. 
+ const float range_start = lower_bound_ + idx * bin_width_; + hist_freq_[idx] += GetRangeFrequencies(histogram, collected_bin_width, + collected_lower_bound, range_start, + range_start + bin_width_); } } std::optional CalibrationStatisticsCollectorHistogram::GetStatistics() const { - if (bin_width_ == 0) return std::nullopt; + if (hist_freq_.empty()) return std::nullopt; CalibrationStatistics::HistogramStatistics hist_stats; + // Skip trailing zeros in the histogram. + int32_t real_size = hist_freq_.size(); + for (; real_size > 0; --real_size) { + if (hist_freq_[real_size - 1] != 0) break; + } + hist_stats.set_lower_bound(lower_bound_); hist_stats.set_bin_width(bin_width_); - hist_stats.mutable_hist_freq()->Assign(hist_freq_.begin(), hist_freq_.end()); + hist_stats.mutable_hist_freq()->Assign(hist_freq_.begin(), + hist_freq_.begin() + real_size); CalibrationStatistics statistics; statistics.mutable_histogram_statistics()->CopyFrom(hist_stats); @@ -77,28 +111,23 @@ CalibrationStatisticsCollectorHistogram::GetStatistics() const { return statistics; } -int CalibrationStatisticsCollectorHistogram::ExpandHistogramIfNeeded(int idx) { - // If idx < 0, then expand the histogram to the left. - if (idx < 0) { - hist_freq_.insert(hist_freq_.begin(), -idx, 0); - lower_bound_ -= bin_width_ * (-idx); - idx = 0; +std::pair +CalibrationStatisticsCollectorHistogram::ExpandHistogramIfNeeded( + const float lower_bound, const float upper_bound) { + int32_t lower_idx = CalculateBinIndex(lower_bound, lower_bound_, bin_width_); + // If lower_idx < 0, then expand the histogram to the left. + if (lower_idx < 0) { + hist_freq_.insert(hist_freq_.begin(), -lower_idx, 0); + lower_bound_ -= bin_width_ * (-lower_idx); + lower_idx = 0; } - // If idx >= hist_freq_.size(), then expand the histogram to the left. 
- if (idx >= hist_freq_.size()) { - hist_freq_.resize(idx + 1, 0); + int32_t upper_idx = CalculateBinIndex(upper_bound, lower_bound_, bin_width_); + // If upper_idx >= hist_freq_.size(), then expand the histogram to the right. + if (upper_idx >= hist_freq_.size()) { + hist_freq_.resize(upper_idx + 1, 0); } - - return idx; -} - -int CalibrationStatisticsCollectorHistogram::GetHistogramIndex( - const float value) { - // Calculate index of histogram - int idx = (value - lower_bound_) / bin_width_; - - return ExpandHistogramIfNeeded(idx); + return std::make_pair(lower_idx, upper_idx); } } // namespace calibrator diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h index cbb2c22c90863e..84f641a5ad0c92 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h @@ -18,7 +18,9 @@ limitations under the License. #include #include #include +#include +#include "absl/types/span.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" @@ -27,42 +29,29 @@ limitations under the License. 
namespace tensorflow { namespace calibrator { -using ::stablehlo::quantization::CalibrationOptions; class CalibrationStatisticsCollectorHistogram : public CalibrationStatisticsCollectorBase { public: - explicit CalibrationStatisticsCollectorHistogram( - const CalibrationOptions& calib_opts) { - ClearData(); - num_bins_ = calib_opts.calibration_parameters().initial_num_bins(); - } + explicit CalibrationStatisticsCollectorHistogram() { ClearData(); } void ClearData() override; - void Collect(const float* data, unsigned int N) override; + void Collect(float min, float max, + absl::Span histogram) override; std::optional GetStatistics() const override; private: - // Returns expanded histogram's index. If idx < 0, then expand the histogram - // to the left. If idx >= hist_freq_.size(), then expand the histogram to the - // right. - int ExpandHistogramIfNeeded(int idx); - - // Calculate the histogram index of value and if index of value is exceeds the - // range of histogram, then this function extends hist_freq_ and updates - // lower_bound_. This function returns the expanded histogram's index. - int GetHistogramIndex(float value); + // Expands the histogram so the lower_bound and upper_bound can fit in the + // histogram. Returns the indexes associated to those values. + std::pair ExpandHistogramIfNeeded(float lower_bound, + float upper_bound); // hist_freq_[i] saves frequency of range [bins[i], bins[i + 1]). // bins[i] = lower_bound_ + bin_width_ * i // bins[i + 1] = lower_bound_ + bin_width_ * (i + 1) - std::deque hist_freq_; - - // The number of bins when histogram is initialized. It can be increased - // because histogram is dynamically expanded by sample inputs. 
- int num_bins_; + std::deque hist_freq_; // Width of bin float bin_width_; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.cc index d549344fcc4a1f..50b7590a2db0b5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h" #include +#include #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" namespace tensorflow { @@ -33,18 +35,12 @@ void CalibrationStatisticsCollectorMinMax::ClearData() { min_max_statistics_.set_global_max(std::numeric_limits::lowest()); } -void CalibrationStatisticsCollectorMinMax::Collect(const float *data, - const unsigned int N) { - float input_min = min_max_statistics_.global_min(); - float input_max = min_max_statistics_.global_max(); - - for (int i = 0; i < N; ++i) { - input_min = std::min(input_min, data[i]); - input_max = std::max(input_max, data[i]); - } - - min_max_statistics_.set_global_min(input_min); - min_max_statistics_.set_global_max(input_max); +void CalibrationStatisticsCollectorMinMax::Collect( + const float min, const float max, absl::Span histogram) { + min_max_statistics_.set_global_min( + std::min(min_max_statistics_.global_min(), min)); + min_max_statistics_.set_global_max( + std::max(min_max_statistics_.global_max(), max)); } std::optional diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h 
b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h index c282bc29987755..8ee545e53f36b7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h @@ -15,8 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_MIN_MAX_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_MIN_MAX_H_ +#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" @@ -37,7 +39,8 @@ class CalibrationStatisticsCollectorMinMax void ClearData() override; - void Collect(const float *data, unsigned int N) override; + void Collect(float min, float max, + absl::Span histogram) override; std::optional GetStatistics() const override; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_test.cc index 3a87488649019b..5e291bec868537 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include -#include +#include #include #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h" #include "tensorflow/core/platform/test.h" @@ -26,19 +26,16 @@ namespace tensorflow { namespace calibrator { namespace { +using ::testing::ElementsAre; + TEST(CalibrationStatisticsCollectorTest, SimpleMinMax) { auto collector = CalibrationStatisticsCollectorMinMax(); - std::vector> collect_vec; - - collect_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - collect_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 10.0f}); - collect_vec.push_back({-5.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + collector.Collect( + /*min=*/1.0f, /*max=*/10.f, /*histogram=*/{}); + collector.Collect( + /*min=*/-5.0f, /*max=*/5.f, /*histogram=*/{}); - for (auto data_vec : collect_vec) { - collector.CalibrationStatisticsCollectorBase::Collect( - /*data_vec=*/data_vec); - } std::optional statistics = collector.GetStatistics(); EXPECT_TRUE(statistics.has_value()); @@ -49,42 +46,26 @@ TEST(CalibrationStatisticsCollectorTest, SimpleMinMax) { TEST(CalibrationStatisticsCollectorTest, SimpleAverageMinMax) { auto collector = CalibrationStatisticsCollectorAverageMinMax(); - std::vector> collect_vec; - - collect_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); // min=1.0f, max=5.0f - collect_vec.push_back( - {1.0f, 2.0f, 3.0f, 4.0f, 10.0f}); // min=1.0f, max=10.0f - collect_vec.push_back( - {-5.0f, 2.0f, 3.0f, 4.0f, 5.0f}); // min=-5.0f, max=5.0f + 
collector.Collect( + /*min=*/1.0f, /*max=*/10.f, /*histogram=*/{}); + collector.Collect( + /*min=*/-5.0f, /*max=*/5.f, /*histogram=*/{}); - for (auto data_vec : collect_vec) { - collector.CalibrationStatisticsCollectorBase::Collect( - /*data_vec=*/data_vec); - } std::optional statistics = collector.GetStatistics(); EXPECT_TRUE(statistics.has_value()); - // 1.0f + 1.0f - 5.0f - EXPECT_EQ(statistics.value().average_min_max_statistics().min_sum(), -3.0f); - // 5.0f + 10.0f + 5.0f - EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 20.0f); - // collect_vec.size() - EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 3); + EXPECT_EQ(statistics.value().average_min_max_statistics().min_sum(), -4.0f); + EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 15.0f); + EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 2); } TEST(CalibrationStatisticsCollectorTest, ClearDataAndGetResultsMinMax) { auto collector = CalibrationStatisticsCollectorMinMax(); - std::vector> collect_vec; - - collect_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - collect_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 10.0f}); - collect_vec.push_back({-5.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - - for (auto data_vec : collect_vec) { - collector.CalibrationStatisticsCollectorBase::Collect( - /*data_vec=*/data_vec); - } + collector.Collect( + /*min=*/1.0f, /*max=*/10.f, /*histogram=*/{}); + collector.Collect( + /*min=*/-5.0f, /*max=*/5.f, /*histogram=*/{}); std::optional statistics = collector.GetStatistics(); @@ -96,11 +77,10 @@ TEST(CalibrationStatisticsCollectorTest, ClearDataAndGetResultsMinMax) { statistics = collector.GetStatistics(); EXPECT_FALSE(statistics.has_value()); - collect_vec.pop_back(); // pop last element - for (auto data_vec : collect_vec) { - collector.CalibrationStatisticsCollectorBase::Collect( - /*data_vec=*/data_vec); - } + collector.Collect( + /*min=*/1.0f, /*max=*/10.f, /*histogram=*/{}); + collector.Collect( + 
/*min=*/2.0f, /*max=*/5.f, /*histogram=*/{}); statistics = collector.GetStatistics(); @@ -112,41 +92,213 @@ TEST(CalibrationStatisticsCollectorTest, ClearDataAndGetResultsMinMax) { TEST(CalibrationStatisticsCollectorTest, ClearDataAndGetResultsAverageMinMax) { auto collector = CalibrationStatisticsCollectorAverageMinMax(); - std::vector> collect_vec; - - collect_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - collect_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 20.0f}); - collect_vec.push_back({-5.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - - for (auto data_vec : collect_vec) { - collector.CalibrationStatisticsCollectorBase::Collect( - /*data_vec=*/data_vec); - } + collector.Collect( + /*min=*/1.0f, /*max=*/10.f, /*histogram=*/{}); + collector.Collect( + /*min=*/-5.0f, /*max=*/5.f, /*histogram=*/{}); std::optional statistics = collector.GetStatistics(); EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().average_min_max_statistics().min_sum(), -3.0f); - EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 30.0f); - EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 3); + EXPECT_EQ(statistics.value().average_min_max_statistics().min_sum(), -4.0f); + EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 15.0f); + EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 2); collector.ClearData(); statistics = collector.GetStatistics(); EXPECT_FALSE(statistics.has_value()); - collect_vec.pop_back(); // pop last element - for (auto data_vec : collect_vec) { - collector.CalibrationStatisticsCollectorBase::Collect( - /*data_vec=*/data_vec); - } + collector.Collect( + /*min=*/1.0f, /*max=*/10.f, /*histogram=*/{}); statistics = collector.GetStatistics(); EXPECT_TRUE(statistics.has_value()); - EXPECT_EQ(statistics.value().average_min_max_statistics().min_sum(), 2.0f); - EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 25.0f); - 
EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 2); + EXPECT_EQ(statistics.value().average_min_max_statistics().min_sum(), 1.0f); + EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 10.0f); + EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 1); } + +TEST(HistogramStatisticsCollectorTest, SingleBatchSimple) { + CalibrationOptions calib_opts; + calib_opts.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY); + auto collector = CalibrationStatisticsCollectorHistogram(); + + collector.Collect( + /*min=*/1.f, /*max=*/16.f, /*histogram=*/{1, 0, 3, 5, 7, 6, 5, 0}); + + std::optional statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), 0.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + // Trailing zeros should be removed. + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 0, 3, 5, 7, 6, 5)); +} + +TEST(HistogramStatisticsCollectorTest, AggregateSameBatchSize) { + CalibrationOptions calib_opts; + calib_opts.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY); + auto collector = CalibrationStatisticsCollectorHistogram(); + + collector.Collect( + /*min=*/1.f, /*max=*/16.f, /*histogram=*/{1, 0, 3, 5, 7, 6, 5, 1}); + + std::optional statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), 0.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 0, 3, 5, 7, 6, 5, 1)); + + collector.Collect( + /*min=*/-1.f, /*max=*/12.f, /*histogram=*/{1, 0, 1, 2, 2, 1, 1, 0}); + + statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + 
EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), -2.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 1, 1, 5, 7, 8, 7, 5, 1)); +} + +TEST(HistogramStatisticsCollectorTest, AggregateSmallerBatchSizeExpandLeft) { + CalibrationOptions calib_opts; + calib_opts.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY); + auto collector = CalibrationStatisticsCollectorHistogram(); + + collector.Collect( + /*min=*/1.f, /*max=*/16.f, /*histogram=*/{1, 0, 3, 5, 7, 6, 5, 1}); + + std::optional statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), 0.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 0, 3, 5, 7, 6, 5, 1)); + + collector.Collect( + /*min=*/-1.f, /*max=*/5.f, /*histogram=*/{1, 0, 1, 2, 2, 1, 1, 0}); + + statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), -2.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 2, 4, 5, 5, 7, 6, 5, 1)); +} + +TEST(HistogramStatisticsCollectorTest, AggregateSmallerBatchSizeExpandRight) { + CalibrationOptions calib_opts; + calib_opts.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY); + auto collector = CalibrationStatisticsCollectorHistogram(); + + collector.Collect( + /*min=*/1.f, /*max=*/16.f, /*histogram=*/{1, 0, 3, 5, 7, 6, 5, 1}); + + std::optional statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), 0.f); + 
EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 0, 3, 5, 7, 6, 5, 1)); + + collector.Collect( + /*min=*/13.f, /*max=*/19.f, /*histogram=*/{1, 0, 1, 2, 2, 1, 1, 0}); + + statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), 0.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 0, 3, 5, 7, 6, 6, 2, 4, 2)); +} + +TEST(HistogramStatisticsCollectorTest, AggregateTinyBinWidth) { + CalibrationOptions calib_opts; + calib_opts.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY); + auto collector = CalibrationStatisticsCollectorHistogram(); + + collector.Collect( + /*min=*/1.f, /*max=*/16.f, /*histogram=*/{1, 0, 3, 5, 7, 6, 5, 1}); + + std::optional statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), 0.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 0, 3, 5, 7, 6, 5, 1)); + + collector.Collect( + /*min=*/-1.f, /*max=*/-0.99998f, /*histogram=*/{1, 0, 1, 2, 2, 1, 1, 0}); + + statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), -2.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(8, 1, 0, 3, 5, 7, 6, 5, 1)); +} + +TEST(HistogramStatisticsCollectorTest, AggregateLargerBatchSizeExpandLeft) { + CalibrationOptions calib_opts; + calib_opts.set_calibration_method( + 
CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY); + auto collector = CalibrationStatisticsCollectorHistogram(); + + collector.Collect( + /*min=*/1.f, /*max=*/16.f, /*histogram=*/{1, 0, 3, 5, 7, 6, 5, 1}); + + std::optional statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), 0.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 0, 3, 5, 7, 6, 5, 1)); + + collector.Collect( + /*min=*/-5.f, /*max=*/5.f, /*histogram=*/{1, 2, 2, 1}); + + statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), -8.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(0.5, 0.5, 1, 1, 2, 1, 3.5, 5.5, 7, 6, 5, 1)); +} + +TEST(HistogramStatisticsCollectorTest, AggregateLargerBatchSizeExpandRight) { + CalibrationOptions calib_opts; + calib_opts.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY); + auto collector = CalibrationStatisticsCollectorHistogram(); + + collector.Collect( + /*min=*/1.f, /*max=*/16.f, /*histogram=*/{1, 0, 3, 5, 7, 6, 5, 1}); + + std::optional statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), 0.f); + EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 0, 3, 5, 7, 6, 5, 1)); + + collector.Collect( + /*min=*/10.f, /*max=*/21.f, /*histogram=*/{1, 2, 2, 1}); + + statistics = collector.GetStatistics(); + EXPECT_TRUE(statistics.has_value()); + EXPECT_EQ(statistics.value().histogram_statistics().lower_bound(), 0.f); + 
EXPECT_EQ(statistics.value().histogram_statistics().bin_width(), 2.f); + EXPECT_THAT(statistics.value().histogram_statistics().hist_freq(), + ElementsAre(1, 0, 3, 5, 7.5, 6.5, 6, 2, 1, 1, 0.5, 0.5)); +} + } // namespace } // namespace calibrator } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc new file mode 100644 index 00000000000000..8061ad3fe2d444 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc @@ -0,0 +1,187 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tsl/platform/file_system.h" + +namespace tensorflow { +namespace { + +using ::stablehlo::quantization::CalibrationOptions; +using CalibrationMethod = + ::stablehlo::quantization::CalibrationOptions_CalibrationMethod; +using ::tensorflow::calibrator::CalibrationStatistics; +using ::tensorflow::calibrator::CalibrationStatisticsCollectorAverageMinMax; +using ::tensorflow::calibrator::CalibrationStatisticsCollectorBase; +using ::tensorflow::calibrator::CalibrationStatisticsCollectorHistogram; +using ::tensorflow::calibrator::CalibrationStatisticsCollectorMinMax; +using ::tensorflow::calibrator::CalibrationStatisticsMap; + +} // namespace + +REGISTER_OP("CalibrationStatisticsSaver") + .Input("args: Tin") + .Attr("Tin: list(type) >= 0") + .Attr("ids: list(string) >= 
1") + .Attr("calibration_methods: list(int) >= 1") + .Attr("output_file_path: string") + .SetIsStateful() + .Doc(R"doc( +Aggregates and saves the calibration statistics data. + +This op collects outputs of multiples CustomAggregator ops, which includes +`min`, `max` and `histogram`. Then it aggregates them according to the +calibration method and save the result to the given file path as a binary +proto file.)doc"); + +class CalibrationStatisticsSaverOp : public OpKernel { + public: + explicit CalibrationStatisticsSaverOp( + absl::Nonnull context) + : OpKernel(context) { + std::string output_file_path; + OP_REQUIRES_OK(context, + context->GetAttr("output_file_path", &output_file_path)); + OP_REQUIRES_OK(context, context->env()->NewWritableFile(output_file_path, + &output_file_)); + + OP_REQUIRES_OK(context, context->GetAttr("ids", &ids_)); + OP_REQUIRES_OK(context, context->GetAttr("calibration_methods", + &calibration_methods_)); + OP_REQUIRES( + context, ids_.size() == calibration_methods_.size(), + absl::AbortedError( + "The `ids` and `calibration_methods` must have the same size.")); + + // Check the number and type of inputs. + OP_REQUIRES(context, context->num_inputs() == ids_.size() * 3, + absl::AbortedError("The number of inputs must be three times " + "the size of the `ids` list.")); + for (int i = 0; i < ids_.size(); ++i) { + OP_REQUIRES(context, context->input_type(i * 3) == DT_FLOAT, + absl::AbortedError("The input `min` must have float type.")); + OP_REQUIRES(context, context->input_type(i * 3 + 1) == DT_FLOAT, + absl::AbortedError("The input `max` must have float type.")); + OP_REQUIRES( + context, context->input_type(i * 3 + 2) == DT_INT64, + absl::AbortedError("The input `histogram` must have int64 type.")); + } + } + + ~CalibrationStatisticsSaverOp() override { + // Save to file during destruction so we only save it once. + // TODO - b/335044516 : Find a way to flush outside of the destructor. 
+    CalibrationStatisticsMap statistics_map;
+    for (const auto& [id, collector] : id_to_collector_) {
+      std::optional<CalibrationStatistics> statistics =
+          collector->GetStatistics();
+      if (!statistics.has_value()) continue;
+
+      statistics_map.mutable_statistics()->emplace(id, std::move(*statistics));
+    }
+
+    if (auto status = output_file_->Append(statistics_map.SerializeAsString());
+        !status.ok()) {
+      LOG(ERROR) << "Failed to write calibration statistics: "
+                 << status.message();
+    }
+    if (auto status = output_file_->Close(); !status.ok()) {
+      LOG(ERROR) << "Failed to close calibration statistics file: "
+                 << status.message();
+    }
+  }
+
+  void Compute(absl::Nonnull<OpKernelContext*> context) override {
+    for (int idx = 0; idx < ids_.size(); ++idx) {
+      AssignIfNotExists(
+          ids_[idx], static_cast<CalibrationMethod>(calibration_methods_[idx]));
+
+      const Tensor& min_tensor = context->input(3 * idx);
+      const Tensor& max_tensor = context->input(3 * idx + 1);
+      const Tensor& histogram_tensor = context->input(3 * idx + 2);
+
+      const float min_value = min_tensor.scalar<float>()();
+      const float max_value = max_tensor.scalar<float>()();
+      auto histogram_flat = histogram_tensor.flat<int64_t>();
+      absl::Span<const int64_t> histogram_data =
+          absl::MakeSpan(histogram_flat.data(), histogram_flat.size());
+      id_to_collector_[ids_[idx]]->Collect(min_value, max_value,
+                                           histogram_data);
+    }
+  }
+
+ private:
+  // The path to save calibration statistics data.
+  std::unique_ptr<tsl::WritableFile> output_file_;
+  // The id and calibration method of preceding CustomAggregator ops.
+  std::vector<std::string> ids_;
+  std::vector<int64_t> calibration_methods_;
+  // Map from id to its collector instance.
+  absl::flat_hash_map<std::string, std::unique_ptr<CalibrationStatisticsCollectorBase>>
+      id_to_collector_;
+
+  void AssignIfNotExists(absl::string_view id,
+                         const CalibrationMethod calibration_method) {
+    std::unique_ptr<CalibrationStatisticsCollectorBase>& collector =
+        id_to_collector_[id];
+
+    if (collector != nullptr) return;
+
+    switch (calibration_method) {
+      case CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX:
+        collector =
+            std::make_unique<CalibrationStatisticsCollectorAverageMinMax>();
+        break;
+      case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE:
+      case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE:
+      case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC:
+      case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY:
+        collector = std::make_unique<CalibrationStatisticsCollectorHistogram>();
+        break;
+      case CalibrationOptions::CALIBRATION_METHOD_MIN_MAX:
+      default:
+        collector = std::make_unique<CalibrationStatisticsCollectorMinMax>();
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("CalibrationStatisticsSaver").Device(DEVICE_CPU),
+                        CalibrationStatisticsSaverOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op_test.cc
new file mode 100644
index 00000000000000..8335722cdea929
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op_test.cc
@@ -0,0 +1,291 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include + +#include +#include +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/test.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/status_matchers.h" + +namespace tensorflow { +namespace { + +using ::stablehlo::quantization::CalibrationOptions; +using ::tensorflow::calibrator::CalibrationStatistics; +using ::tensorflow::calibrator::CalibrationStatisticsMap; +using ::testing::Contains; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::Key; +using ::testing::SizeIs; +using ::tsl::testing::StatusIs; + +class CalibrationStatisticsSaverTest : public OpsTestBase {}; + +TEST_F(CalibrationStatisticsSaverTest, MissingOutputPath) { + std::vector ids{"1"}; + std::vector calibration_methods{ + CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX}; + + std::vector inputs; + inputs.emplace_back("min", 0, DT_FLOAT); + inputs.emplace_back("max", 0, DT_FLOAT); + + TF_CHECK_OK(NodeDefBuilder("op", "CalibrationStatisticsSaver") + .Input(inputs) + .Attr("ids", ids) + .Attr("calibration_methods", calibration_methods) + .Finalize(node_def())); + ASSERT_THAT(InitOp(), + StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("NodeDef missing attr 'output_file_path'"))); +} + +TEST_F(CalibrationStatisticsSaverTest, WrongNumInputs) { + std::vector ids{"1"}; + std::vector calibration_methods{ + CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX}; + + std::vector inputs; + inputs.emplace_back("min", 
0, DT_FLOAT); + inputs.emplace_back("max", 0, DT_FLOAT); + + TF_CHECK_OK(NodeDefBuilder("op", "CalibrationStatisticsSaver") + .Input(inputs) + .Attr("ids", ids) + .Attr("calibration_methods", calibration_methods) + .Attr("output_file_path", "/tmp/statistics.pbtxt") + .Finalize(node_def())); + ASSERT_THAT(InitOp(), + StatusIs(tsl::error::ABORTED, + HasSubstr("The number of inputs must be three times " + "the size of the `ids` list."))); +} + +TEST_F(CalibrationStatisticsSaverTest, WrongInputTypes) { + std::vector ids{"1"}; + std::vector calibration_methods{ + CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX}; + + std::vector inputs; + inputs.emplace_back("min", 0, DT_FLOAT); + inputs.emplace_back("max", 0, DT_FLOAT); + inputs.emplace_back("histogram", 0, DT_FLOAT); + + TF_CHECK_OK(NodeDefBuilder("op", "CalibrationStatisticsSaver") + .Input(inputs) + .Attr("ids", ids) + .Attr("calibration_methods", calibration_methods) + .Attr("output_file_path", "/tmp/statistics.pbtxt") + .Finalize(node_def())); + ASSERT_THAT( + InitOp(), + StatusIs(tsl::error::ABORTED, + HasSubstr("The input `histogram` must have int64 type"))); +} + +TEST_F(CalibrationStatisticsSaverTest, SimpleMinMax) { + std::vector ids{"1"}; + std::vector calibration_methods{ + CalibrationOptions::CALIBRATION_METHOD_MIN_MAX}; + + std::vector inputs; + inputs.emplace_back("min", 0, DT_FLOAT); + inputs.emplace_back("max", 0, DT_FLOAT); + inputs.emplace_back("histogram", 0, DT_INT64); + + const std::string dir = testing::TmpDir(); + const std::string output_file_path = io::JoinPath(dir, "statistics.pbtxt"); + + TF_CHECK_OK(NodeDefBuilder("op", "CalibrationStatisticsSaver") + .Input(inputs) + .Attr("ids", ids) + .Attr("calibration_methods", calibration_methods) + .Attr("output_file_path", output_file_path) + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + + AddInputFromArray(TensorShape({}), {1.f}); + AddInputFromArray(TensorShape({}), {5.f}); + AddInputFromArray(TensorShape({0}), {}); + + 
TF_CHECK_OK(RunOpKernel()); + kernel_.reset(); + + CalibrationStatisticsMap statistics_map; + TF_CHECK_OK( + ReadBinaryProto(Env::Default(), output_file_path, &statistics_map)); + ASSERT_THAT(statistics_map.statistics(), SizeIs(1)); + ASSERT_THAT(statistics_map.statistics(), ElementsAre(Key("1"))); + + const CalibrationStatistics& stats = statistics_map.statistics().at("1"); + ASSERT_TRUE(stats.has_min_max_statistics()); + EXPECT_FLOAT_EQ(stats.min_max_statistics().global_min(), 1.f); + EXPECT_FLOAT_EQ(stats.min_max_statistics().global_max(), 5.f); +} + +TEST_F(CalibrationStatisticsSaverTest, SimpleAverageMinMax) { + std::vector ids{"1"}; + std::vector calibration_methods{ + CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX}; + + std::vector inputs; + inputs.emplace_back("min", 0, DT_FLOAT); + inputs.emplace_back("max", 0, DT_FLOAT); + inputs.emplace_back("histogram", 0, DT_INT64); + + const std::string dir = testing::TmpDir(); + const std::string output_file_path = io::JoinPath(dir, "statistics.pbtxt"); + + TF_CHECK_OK(NodeDefBuilder("op", "CalibrationStatisticsSaver") + .Input(inputs) + .Attr("ids", ids) + .Attr("calibration_methods", calibration_methods) + .Attr("output_file_path", output_file_path) + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + + AddInputFromArray(TensorShape({}), {1.f}); + AddInputFromArray(TensorShape({}), {5.f}); + AddInputFromArray(TensorShape({0}), {}); + + TF_CHECK_OK(RunOpKernel()); + kernel_.reset(); + + CalibrationStatisticsMap statistics_map; + TF_CHECK_OK( + ReadBinaryProto(Env::Default(), output_file_path, &statistics_map)); + ASSERT_THAT(statistics_map.statistics(), SizeIs(1)); + ASSERT_THAT(statistics_map.statistics(), ElementsAre(Key("1"))); + + const CalibrationStatistics& stats = statistics_map.statistics().at("1"); + ASSERT_TRUE(stats.has_average_min_max_statistics()); + EXPECT_FLOAT_EQ(stats.average_min_max_statistics().min_sum(), 1.f); + EXPECT_FLOAT_EQ(stats.average_min_max_statistics().max_sum(), 5.f); + 
EXPECT_EQ(stats.average_min_max_statistics().num_samples(), 1); +} + +TEST_F(CalibrationStatisticsSaverTest, SimpleHistogram) { + std::vector ids{"1"}; + std::vector calibration_methods{ + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE}; + + std::vector inputs; + inputs.emplace_back("min", 0, DT_FLOAT); + inputs.emplace_back("max", 0, DT_FLOAT); + inputs.emplace_back("histogram", 0, DT_INT64); + + const std::string dir = testing::TmpDir(); + const std::string output_file_path = io::JoinPath(dir, "statistics.pbtxt"); + + TF_CHECK_OK(NodeDefBuilder("op", "CalibrationStatisticsSaver") + .Input(inputs) + .Attr("ids", ids) + .Attr("calibration_methods", calibration_methods) + .Attr("output_file_path", output_file_path) + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + + AddInputFromArray(TensorShape({}), {1.f}); + AddInputFromArray(TensorShape({}), {5.f}); + AddInputFromArray(TensorShape({8}), {1, 4, 6, 7, 3, 2, 1, 0}); + + TF_CHECK_OK(RunOpKernel()); + kernel_.reset(); + + CalibrationStatisticsMap statistics_map; + TF_CHECK_OK( + ReadBinaryProto(Env::Default(), output_file_path, &statistics_map)); + ASSERT_THAT(statistics_map.statistics(), SizeIs(1)); + ASSERT_THAT(statistics_map.statistics(), ElementsAre(Key("1"))); + + const CalibrationStatistics& stats = statistics_map.statistics().at("1"); + ASSERT_TRUE(stats.has_histogram_statistics()); + EXPECT_FLOAT_EQ(stats.histogram_statistics().bin_width(), 0.5f); + EXPECT_FLOAT_EQ(stats.histogram_statistics().lower_bound(), 1.f); + EXPECT_THAT(stats.histogram_statistics().hist_freq(), + ElementsAre(1, 4, 6, 7, 3, 2, 1)); +} + +TEST_F(CalibrationStatisticsSaverTest, MultipleStats) { + std::vector ids{"1", "2"}; + std::vector calibration_methods{ + CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX, + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE}; + + std::vector inputs; + inputs.emplace_back("min", 0, DT_FLOAT); + inputs.emplace_back("max", 0, DT_FLOAT); + 
inputs.emplace_back("histogram", 0, DT_INT64); + inputs.emplace_back("min", 0, DT_FLOAT); + inputs.emplace_back("max", 0, DT_FLOAT); + inputs.emplace_back("histogram", 0, DT_INT64); + + const std::string dir = testing::TmpDir(); + const std::string output_file_path = io::JoinPath(dir, "statistics.pbtxt"); + + TF_CHECK_OK(NodeDefBuilder("op", "CalibrationStatisticsSaver") + .Input(inputs) + .Attr("ids", ids) + .Attr("calibration_methods", calibration_methods) + .Attr("output_file_path", output_file_path) + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + + AddInputFromArray(TensorShape({}), {1.f}); + AddInputFromArray(TensorShape({}), {5.f}); + AddInputFromArray(TensorShape({0}), {}); + AddInputFromArray(TensorShape({}), {1.f}); + AddInputFromArray(TensorShape({}), {5.f}); + AddInputFromArray(TensorShape({8}), {1, 4, 6, 7, 3, 2, 1, 0}); + + TF_CHECK_OK(RunOpKernel()); + kernel_.reset(); + + CalibrationStatisticsMap statistics_map; + TF_CHECK_OK( + ReadBinaryProto(Env::Default(), output_file_path, &statistics_map)); + ASSERT_THAT(statistics_map.statistics(), SizeIs(2)); + ASSERT_THAT(statistics_map.statistics(), Contains(Key("1"))); + ASSERT_THAT(statistics_map.statistics(), Contains(Key("2"))); + + const CalibrationStatistics& stats_1 = statistics_map.statistics().at("1"); + ASSERT_TRUE(stats_1.has_average_min_max_statistics()); + EXPECT_FLOAT_EQ(stats_1.average_min_max_statistics().min_sum(), 1.f); + EXPECT_FLOAT_EQ(stats_1.average_min_max_statistics().max_sum(), 5.f); + EXPECT_EQ(stats_1.average_min_max_statistics().num_samples(), 1); + + const CalibrationStatistics& stats_2 = statistics_map.statistics().at("2"); + ASSERT_TRUE(stats_2.has_histogram_statistics()); + EXPECT_FLOAT_EQ(stats_2.histogram_statistics().bin_width(), 0.5f); + EXPECT_FLOAT_EQ(stats_2.histogram_statistics().lower_bound(), 1.f); + EXPECT_THAT(stats_2.histogram_statistics().hist_freq(), + ElementsAre(1, 4, 6, 7, 3, 2, 1)); +} + +} // namespace +} // namespace tensorflow diff --git 
a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc index 7fe3b34c8137d1..74575b761737a3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc @@ -60,40 +60,27 @@ void CalibratorSingleton::ClearData(absl::string_view id) { instance.id_to_collector_[id_str].reset(nullptr); } -void CalibratorSingleton::Report(absl::string_view id, - absl::Span data_span, +void CalibratorSingleton::Report(absl::string_view id, const Tensor& min_tensor, + const Tensor& max_tensor, + const Tensor& histogram_tensor, const CalibrationOptions& calib_opts) { - absl::MutexLock lock(&lock_); - - CalibratorSingleton& instance = GetInstance(); - - const std::string id_str{id}; - AssignIfNotExists(id_str, calib_opts); - instance.id_to_collector_[id_str]->Collect(data_span); + const float min_value = min_tensor.scalar()(); + const float max_value = max_tensor.scalar()(); + auto histogram_flat = histogram_tensor.flat(); + absl::Span histogram_data = + absl::MakeSpan(histogram_flat.data(), histogram_flat.size()); + Report(id, min_value, max_value, histogram_data, calib_opts); } -void CalibratorSingleton::Report(absl::string_view id, - const std::vector& data_vec, +void CalibratorSingleton::Report(absl::string_view id, float min, float max, + absl::Span histogram, const CalibrationOptions& calib_opts) { absl::MutexLock lock(&lock_); CalibratorSingleton& instance = GetInstance(); - const std::string id_str{id}; AssignIfNotExists(id_str, calib_opts); - instance.id_to_collector_[id_str]->Collect(data_vec); -} - -void CalibratorSingleton::Report(absl::string_view id, - const Tensor& data_tensor, - const CalibrationOptions& calib_opts) { - absl::MutexLock lock(&lock_); - - CalibratorSingleton& instance = GetInstance(); - - const std::string id_str{id}; - 
AssignIfNotExists(id_str, calib_opts); - instance.id_to_collector_[id_str]->Collect(data_tensor); + instance.id_to_collector_[id_str]->Collect(min, max, histogram); } std::optional CalibratorSingleton::GetStatistics( @@ -111,37 +98,27 @@ std::optional CalibratorSingleton::GetStatistics( return instance.id_to_collector_[id_str]->GetStatistics(); } -int64_t CalibratorSingleton::IssueNewId() { - CalibratorSingleton& instance = GetInstance(); - return instance.next_id_++; -} - void CalibratorSingleton::AssignIfNotExists( std::string id_str, const CalibrationOptions& calib_opts) { CalibratorSingleton& instance = GetInstance(); - - if (!instance.id_to_collector_[id_str]) { - CalibrationOptions::CalibrationMethod calib_method = - calib_opts.calibration_method(); - - switch (calib_method) { - case CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX: - instance.id_to_collector_[id_str] = - std::make_unique(); - break; - case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE: - case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE: - case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC: - case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY: - instance.id_to_collector_[id_str] = - std::make_unique( - calib_opts); - break; - case CalibrationOptions::CALIBRATION_METHOD_MIN_MAX: - default: - instance.id_to_collector_[id_str] = - std::make_unique(); - } + if (instance.id_to_collector_[id_str]) return; + + switch (calib_opts.calibration_method()) { + case CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX: + instance.id_to_collector_[id_str] = + std::make_unique(); + break; + case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE: + case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE: + case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC: + case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY: + instance.id_to_collector_[id_str] = + std::make_unique(); + 
break; + case CalibrationOptions::CALIBRATION_METHOD_MIN_MAX: + default: + instance.id_to_collector_[id_str] = + std::make_unique(); } } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h index d909dcecf76a66..8a6aee81ee9cbd 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h @@ -47,42 +47,27 @@ class CalibratorSingleton { // Clears the collected data of the given node id. static void ClearData(absl::string_view id); - // Reports data to singleton using float vector. - // Only calculates the required statistics from CalibrationMethod based - // on CalibrationOptions. - static void Report(absl::string_view id, const std::vector& data_vec, + // Reports data to the singleton. Only calculates the required statistics + // based on CalibrationOptions. + static void Report(absl::string_view id, const Tensor& min_tensor, + const Tensor& max_tensor, const Tensor& histogram_tensor, const CalibrationOptions& calib_opts); - // Reports data to singleton using absl::Span - // Only calculates the required statistics from CalibrationMethod based - // on CalibrationOptions. - static void Report(absl::string_view id, absl::Span data_span, - const CalibrationOptions& calib_opts); - - // Reports data to singleton using absl::Span - // Only calculates the required statistics from CalibrationMethod based - // on CalibrationOptions. - static void Report(absl::string_view id, const Tensor& data_tensor, + // Same as above but accepts primitive input types. + static void Report(absl::string_view id, float min, float max, + absl::Span histogram, const CalibrationOptions& calib_opts); // Returns the calibration statistics of the given id. 
static std::optional GetStatistics( absl::string_view id); - // Issues a new node ID that uniquely identifies a set of calibration - // statistics. - static int64_t IssueNewId(); - private: static CalibratorSingleton& GetInstance(); static absl::Mutex lock_; static void AssignIfNotExists(std::string id_str, const CalibrationOptions& calib_opts); - // Indicates the next id for a set of calibration statistics. For every new ID - // issued this will be incremented atomically. - std::atomic next_id_{0}; - absl::flat_hash_map> id_to_collector_; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc index d7652c1b6806c4..ca338b58c5909d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc @@ -30,16 +30,12 @@ namespace { using ::stablehlo::quantization::CalibrationOptions; TEST(CalibratorSingletonTest, SimpleMinMax) { - std::vector> report_vec; CalibrationOptions calib_opts; calib_opts.set_calibration_method( CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); - report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 10.0f}); - report_vec.push_back({-5.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - - CalibratorSingleton::Report(/*id=*/"1", /*data_vec=*/report_vec[0], + CalibratorSingleton::Report(/*id=*/"1", /*min=*/1.0f, /*max=*/5.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); std::optional statistics = CalibratorSingleton::GetStatistics(/*id=*/"1"); @@ -48,7 +44,8 @@ TEST(CalibratorSingletonTest, SimpleMinMax) { EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 5.0f); - CalibratorSingleton::Report(/*id=*/"1", /*data_vec=*/report_vec[1], + CalibratorSingleton::Report(/*id=*/"1", 
/*min=*/1.0f, /*max=*/10.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); statistics = CalibratorSingleton::GetStatistics(/*id=*/"1"); @@ -56,7 +53,8 @@ TEST(CalibratorSingletonTest, SimpleMinMax) { EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 10.0f); - CalibratorSingleton::Report(/*id=*/"1", /*data_vec=*/report_vec[2], + CalibratorSingleton::Report(/*id=*/"1", /*min=*/-5.0f, /*max=*/5.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); statistics = CalibratorSingleton::GetStatistics(/*id=*/"1"); @@ -66,16 +64,12 @@ TEST(CalibratorSingletonTest, SimpleMinMax) { } TEST(CalibratorSingletonTest, DifferentSessions) { - std::vector> report_vec; CalibrationOptions calib_opts; calib_opts.set_calibration_method( CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); - report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 10.0f}); - report_vec.push_back({-5.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - - CalibratorSingleton::Report(/*id=*/"2", /*data_vec=*/report_vec[0], + CalibratorSingleton::Report(/*id=*/"2", /*min=*/1.0f, /*max=*/5.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); std::optional statistics = CalibratorSingleton::GetStatistics(/*id=*/"2"); @@ -84,7 +78,8 @@ TEST(CalibratorSingletonTest, DifferentSessions) { EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 5.0f); - CalibratorSingleton::Report(/*id=*/"2", /*data_vec=*/report_vec[1], + CalibratorSingleton::Report(/*id=*/"2", /*min=*/1.0f, /*max=*/10.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); statistics = CalibratorSingleton::GetStatistics(/*id=*/"2"); @@ -92,7 +87,8 @@ TEST(CalibratorSingletonTest, DifferentSessions) { EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 10.0f); - 
CalibratorSingleton::Report(/*id=*/"3", /*data_vec=*/report_vec[2], + CalibratorSingleton::Report(/*id=*/"3", /*min=*/-5.0f, /*max=*/5.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); statistics = CalibratorSingleton::GetStatistics(/*id=*/"3"); @@ -110,7 +106,8 @@ TEST(CalibratorSingletonTest, ClearAndGetEmptyResult) { report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 10.0f}); - CalibratorSingleton::Report(/*id=*/"4", /*data_vec=*/report_vec[0], + CalibratorSingleton::Report(/*id=*/"4", /*min=*/1.0f, /*max=*/5.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); std::optional statistics = CalibratorSingleton::GetStatistics(/*id=*/"4"); @@ -126,16 +123,12 @@ TEST(CalibratorSingletonTest, ClearAndGetEmptyResult) { } TEST(CalibratorSingletonTest, ClearDataAndGetResults) { - std::vector> report_vec; CalibrationOptions calib_opts; calib_opts.set_calibration_method( CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); - report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - report_vec.push_back({1.0f, 2.0f, 3.0f, 4.0f, 10.0f}); - report_vec.push_back({-5.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - - CalibratorSingleton::Report(/*id=*/"5", /*data_vec=*/report_vec[0], + CalibratorSingleton::Report(/*id=*/"5", /*min=*/1.0f, /*max=*/5.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); std::optional statistics = CalibratorSingleton::GetStatistics(/*id=*/"5"); @@ -144,7 +137,8 @@ TEST(CalibratorSingletonTest, ClearDataAndGetResults) { EXPECT_EQ(statistics.value().min_max_statistics().global_min(), 1.0f); EXPECT_EQ(statistics.value().min_max_statistics().global_max(), 5.0f); - CalibratorSingleton::Report(/*id=*/"6", /*data_vec=*/report_vec[1], + CalibratorSingleton::Report(/*id=*/"6", /*min=*/1.0f, /*max=*/10.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); statistics = CalibratorSingleton::GetStatistics(/*id=*/"6"); @@ -157,7 +151,8 @@ TEST(CalibratorSingletonTest, ClearDataAndGetResults) { EXPECT_FALSE(statistics.has_value()); - 
CalibratorSingleton::Report(/*id=*/"6", /*data_vec=*/report_vec[1], + CalibratorSingleton::Report(/*id=*/"6", /*min=*/1.0f, /*max=*/10.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); statistics = CalibratorSingleton::GetStatistics(/*id=*/"6"); @@ -167,16 +162,12 @@ TEST(CalibratorSingletonTest, ClearDataAndGetResults) { } TEST(CalibratorSingletonTest, SimpleAverageMinMax) { - std::vector> report_vec; CalibrationOptions calib_opts; calib_opts.set_calibration_method( CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX); - report_vec.push_back({-10.0f, 2.0f, 3.0f, 4.0f, 30.0f}); - report_vec.push_back({-20.0f, 2.0f, 3.0f, 4.0f, 60.0f}); - report_vec.push_back({-30.0f, 2.0f, 3.0f, 4.0f, 90.0f}); - - CalibratorSingleton::Report(/*id=*/"7", /*data_vec=*/report_vec[0], + CalibratorSingleton::Report(/*id=*/"7", /*min=*/-10.0f, /*max=*/30.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); std::optional statistics = CalibratorSingleton::GetStatistics(/*id=*/"7"); @@ -186,7 +177,8 @@ TEST(CalibratorSingletonTest, SimpleAverageMinMax) { EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 30.0f); EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 1); - CalibratorSingleton::Report(/*id=*/"7", /*data_vec=*/report_vec[1], + CalibratorSingleton::Report(/*id=*/"7", /*min=*/-20.0f, /*max=*/60.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); statistics = CalibratorSingleton::GetStatistics(/*id=*/"7"); @@ -195,7 +187,8 @@ TEST(CalibratorSingletonTest, SimpleAverageMinMax) { EXPECT_EQ(statistics.value().average_min_max_statistics().max_sum(), 90.0f); EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 2); - CalibratorSingleton::Report(/*id=*/"7", /*data_vec=*/report_vec[2], + CalibratorSingleton::Report(/*id=*/"7", /*min=*/-30.0f, /*max=*/90.0f, + /*histogram=*/{}, /*calib_opts=*/calib_opts); statistics = CalibratorSingleton::GetStatistics(/*id=*/"7"); @@ -205,12 +198,6 @@ TEST(CalibratorSingletonTest, 
SimpleAverageMinMax) { EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 3); } -TEST(CalibratorSingletonTest, IssueNewIdGeneratesNewId) { - const int64_t id = CalibratorSingleton::IssueNewId(); - const int64_t next_id = CalibratorSingleton::IssueNewId(); - EXPECT_NE(id, next_id); -} - } // namespace } // namespace calibrator } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc index b87271f076f3bd..66d932a44f6179 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc @@ -12,32 +12,62 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS +#include #include +#include "absl/status/status.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tsl/platform/errors.h" namespace tensorflow { +namespace { +using ::stablehlo::quantization::CalculateBinIndexSafe; +using 
::stablehlo::quantization::CalculateBinWidth; +using ::stablehlo::quantization::CalculateLowerBound; using ::stablehlo::quantization::CalibrationOptions; +using ::stablehlo::quantization::GetNumBins; +using CPUDevice = ::Eigen::ThreadPoolDevice; +using CalibrationMethod = + ::stablehlo::quantization::CalibrationOptions_CalibrationMethod; + +} // namespace REGISTER_OP("CustomAggregator") .Input("input: float") .Output("output: float") + .Output("min: float") + .Output("max: float") + .Output("histogram: int64") .Attr("id: string") .Attr("calibration_method: int = 0") .Attr("initial_num_bins: int = 0") .Attr("min_percentile: float = 0.0") .Attr("max_percentile: float = 0.0") - .SetIsStateful() .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); - return OkStatus(); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + + const tensorflow::AttrValue* calibration_method_attr; + TF_RETURN_IF_ERROR( + c->GetAttr("calibration_method", &calibration_method_attr)); + int32_t num_bins = GetNumBins( + static_cast(calibration_method_attr->i())); + c->set_output(3, c->MakeShape({num_bins})); + + return absl::OkStatus(); }); class CustomAggregatorOp : public OpKernel { @@ -45,20 +75,29 @@ class CustomAggregatorOp : public OpKernel { explicit CustomAggregatorOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("id", &id_)); + + int calibration_method_value; int initial_num_bins; - int calibration_method; float min_percentile; float max_percentile; - OP_REQUIRES_OK( - context, context->GetAttr("calibration_method", (&calibration_method))); + OP_REQUIRES_OK(context, context->GetAttr("calibration_method", + &calibration_method_value)); OP_REQUIRES_OK(context, context->GetAttr("initial_num_bins", &initial_num_bins)); OP_REQUIRES_OK(context, context->GetAttr("min_percentile", &min_percentile)); OP_REQUIRES_OK(context, context->GetAttr("max_percentile", &max_percentile)); - 
calib_opts_.set_calibration_method( - static_cast(calibration_method)); + + auto calibration_method = + static_cast(calibration_method_value); + OP_REQUIRES( + context, + calibration_method != + CalibrationOptions::CALIBRATION_METHOD_UNSPECIFIED, + absl::AbortedError("The calibration method must be specified.")); + + calib_opts_.set_calibration_method(calibration_method); calib_opts_.mutable_calibration_parameters()->set_initial_num_bins( initial_num_bins); calib_opts_.mutable_calibration_parameters()->set_min_percentile( @@ -70,26 +109,59 @@ class CustomAggregatorOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input_tensor = context->input(0); - auto input_flat = input_tensor.flat(); + // Use the same input for the first output. + context->set_output(0, input_tensor); + + // Calculate min/max statistics. + const auto input_flat = input_tensor.flat(); + Tensor *min_output = nullptr, *max_output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output("min", {}, &min_output)); + OP_REQUIRES_OK(context, context->allocate_output("max", {}, &max_output)); + min_output->scalar().device( + context->template eigen_device()) = input_flat.minimum(); + max_output->scalar().device( + context->template eigen_device()) = input_flat.maximum(); - const int N = input_flat.size(); - if (N == 0) { - // Use the same input for the output. - context->set_output(0, input_tensor); - return; + // Calculate histogram statistics. 
+ int32_t num_bins = GetNumBins(calib_opts_.calibration_method()); + Tensor* histogram_output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output("histogram", {num_bins}, + &histogram_output)); + if (num_bins > 0) { + const float min_value = min_output->scalar()(); + const float max_value = max_output->scalar()(); + CalculateHistogramStatistics(context, input_tensor, min_value, max_value, + num_bins, histogram_output); } // By passing calib_opts_ and input_tensor to CalibratorSingleton, // CalibrationStatisticsCollector can calculate statistics for calibration. - calibrator::CalibratorSingleton::Report(id_, input_tensor, calib_opts_); - - // Use the same input for the output. - context->set_output(0, input_tensor); + calibrator::CalibratorSingleton::Report(id_, *min_output, *max_output, + *histogram_output, calib_opts_); } private: std::string id_; CalibrationOptions calib_opts_; + + void CalculateHistogramStatistics(OpKernelContext* context, + const Tensor& input_tensor, + const float min_value, + const float max_value, + const int32_t num_bins, + Tensor* histogram_tensor) { + const auto input_flat = input_tensor.flat(); + auto histogram_flat = histogram_tensor->flat(); + histogram_flat.setZero(); + + const float bin_width = CalculateBinWidth(min_value, max_value, num_bins); + const float lower_bound = CalculateLowerBound(min_value, bin_width); + for (int i = 0; i < input_flat.size(); ++i) { + int32_t bin_index = CalculateBinIndexSafe( + input_flat.data()[i], lower_bound, bin_width, num_bins); + histogram_flat.data()[bin_index] += 1; + } + } }; REGISTER_KERNEL_BUILDER(Name("CustomAggregator").Device(DEVICE_CPU), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py deleted file mode 100644 index e7e8a5de064426..00000000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py +++ /dev/null 
@@ -1,42 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Custom Aggregator op is for collecting numeric metrics from the given input.""" - -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import custom_aggregator_op_wrapper -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import load_library -from tensorflow.python.platform import resource_loader - -_custom_aggregator_op = load_library.load_op_library( - resource_loader.get_path_to_datafile('_custom_aggregator_op.so')) - - -def custom_aggregator(input_tensor, tensor_id: str): - """Creates custom aggregator op that collects numeric metrics from the tensor. - - Args: - input_tensor: Tensor to be scanned through this operator. This tensor will - be bypassed to the output tensor of this operator. - tensor_id: String, the identity of the tensor to be scanned. - - Returns: - A `Tensor` of the same value as `input_tensor`. - - Raises: - ValueError: If the given type of `input_tensor` is not float32. 
- """ - if input_tensor.dtype != dtypes.float32: - raise ValueError('Custom aggregator op only accept float32 values.') - return custom_aggregator_op_wrapper.custom_aggregator(input_tensor, tensor_id) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py index 4cda958f398ac6..5940803f470117 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py @@ -46,10 +46,14 @@ def testBypassAndMinMax(self): aggregator = custom_aggregator_op_wrapper.custom_aggregator( input_tensor, - '1', + id='1', calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX, ) - self.assertAllEqual(self.evaluate(aggregator), [1.0, 2.0, 3.0, 4.0, 5.0]) + aggregator_output = self.evaluate(aggregator) + self.assertAllEqual(aggregator_output.output, [1.0, 2.0, 3.0, 4.0, 5.0]) + self.assertEqual(aggregator_output.min, 1.0) + self.assertEqual(aggregator_output.max, 5.0) + self.assertEmpty(aggregator_output.histogram) statistics: calib_stat_pb2.CalibrationStatistics = ( pywrap_calibration.get_statistics_from_calibrator('1') @@ -71,7 +75,12 @@ def testTwoIdentities(self): '2', calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX, ) - self.assertAllEqual(self.evaluate(aggregator1), [1.0, 2.0, 3.0, 4.0, 5.0]) + aggregator1_output = self.evaluate(aggregator1) + self.assertAllEqual(aggregator1_output.output, [1.0, 2.0, 3.0, 4.0, 5.0]) + self.assertEqual(aggregator1_output.min, 1.0) + self.assertEqual(aggregator1_output.max, 5.0) + self.assertEmpty(aggregator1_output.histogram) + input_tensor2 = array_ops.constant( [-1.0, -2.0, -3.0, -4.0, -5.0], dtypes.float32 ) @@ -80,9 +89,13 @@ def testTwoIdentities(self): '3', 
calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX, ) + aggregator2_output = self.evaluate(aggregator2) self.assertAllEqual( - self.evaluate(aggregator2), [-1.0, -2.0, -3.0, -4.0, -5.0] + aggregator2_output.output, [-1.0, -2.0, -3.0, -4.0, -5.0] ) + self.assertEqual(aggregator2_output.min, -5.0) + self.assertEqual(aggregator2_output.max, -1.0) + self.assertEmpty(aggregator2_output.histogram) statistics: calib_stat_pb2 = ( pywrap_calibration.get_statistics_from_calibrator('2') @@ -108,7 +121,12 @@ def testClearData(self): '4', calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX, ) - self.assertAllEqual(self.evaluate(aggregator1), [1.0, 2.0, 3.0, 4.0, 5.0]) + aggregator1_output = self.evaluate(aggregator1) + self.assertAllEqual(aggregator1_output.output, [1.0, 2.0, 3.0, 4.0, 5.0]) + self.assertEqual(aggregator1_output.min, 1.0) + self.assertEqual(aggregator1_output.max, 5.0) + self.assertEmpty(aggregator1_output.histogram) + input_tensor2 = array_ops.constant( [-1.0, -2.0, -3.0, -4.0, -5.0], dtypes.float32 ) @@ -117,9 +135,13 @@ def testClearData(self): '5', calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX, ) + aggregator2_output = self.evaluate(aggregator2) self.assertAllEqual( - self.evaluate(aggregator2), [-1.0, -2.0, -3.0, -4.0, -5.0] + aggregator2_output.output, [-1.0, -2.0, -3.0, -4.0, -5.0] ) + self.assertEqual(aggregator2_output.min, -5.0) + self.assertEqual(aggregator2_output.max, -1.0) + self.assertEmpty(aggregator2_output.histogram) statistics: calib_stat_pb2 = ( pywrap_calibration.get_statistics_from_calibrator('4') @@ -157,10 +179,15 @@ def testBypassAndAverageMinMax(self): '6', calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX, ) + aggregator1_output = self.evaluate(aggregator1) self.assertAllEqual( - self.evaluate(aggregator1), + aggregator1_output.output, [-50.0, -25.0, 0.0, 25.0, 50.0], ) + self.assertEqual(aggregator1_output.min, -50.0) + self.assertEqual(aggregator1_output.max, 50.0) 
+ self.assertEmpty(aggregator1_output.histogram) + input_tensor2 = array_ops.constant( [-100.0, -50.0, 0.0, 50.0, 100.0], dtypes.float32 ) @@ -169,9 +196,13 @@ def testBypassAndAverageMinMax(self): '6', calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX, ) + aggregator2_output = self.evaluate(aggregator2) self.assertAllEqual( - self.evaluate(aggregator2), [-100.0, -50.0, 0.0, 50.0, 100.0] + aggregator2_output.output, [-100.0, -50.0, 0.0, 50.0, 100.0] ) + self.assertEqual(aggregator2_output.min, -100.0) + self.assertEqual(aggregator2_output.max, 100.0) + self.assertEmpty(aggregator2_output.histogram) statistics: calib_stat_pb2 = ( pywrap_calibration.get_statistics_from_calibrator('6') @@ -183,6 +214,31 @@ def testBypassAndAverageMinMax(self): self.assertAllEqual((min_sum, max_sum, num_samples), (-150.0, 150.0, 2)) + def testHistogramCalibration(self): + with self.session(): + pywrap_calibration.clear_calibrator() + input_tensor = array_ops.constant( + [1.0, 1.0, 3.0, 4.0, 6.0], dtypes.float32 + ) + + aggregator = custom_aggregator_op_wrapper.custom_aggregator( + input_tensor, + id='7', + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, + initial_num_bins=256, + ) + aggregator_output = self.evaluate(aggregator) + self.assertAllEqual(aggregator_output.output, [1.0, 1.0, 3.0, 4.0, 6.0]) + self.assertEqual(aggregator_output.min, 1.0) + self.assertEqual(aggregator_output.max, 6.0) + + self.assertLen(aggregator_output.histogram, 512) + self.assertEqual(sum(aggregator_output.histogram), 5) + self.assertEqual(aggregator_output.histogram[0], 2) + self.assertEqual(aggregator_output.histogram[128], 1) + self.assertEqual(aggregator_output.histogram[192], 1) + self.assertEqual(aggregator_output.histogram[320], 1) + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index 23ce2105634854..218e229828211a 
100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -173,6 +173,7 @@ tf_cc_test( srcs = ["constant_fold_test.cc"], deps = [ ":constant_fold", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/common:test_base", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:tensorflow", @@ -184,7 +185,6 @@ tf_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc index 16d3f18364efbf..aaaf088b507e07 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc index 0a5bac34b50c3a..e4229cb97bf45a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc @@ -80,7 +80,7 @@ class ConvertCustomAggregationOpToQuantStats // When there are no min and max attributes, remove op. 
if (min == nullptr || max == nullptr) { - op->replaceAllUsesWith(op->getOperands()); + op.getOutput().replaceAllUsesWith(op.getInput()); rewriter.eraseOp(op); return success(); } @@ -93,8 +93,9 @@ class ConvertCustomAggregationOpToQuantStats ElementsAttr axis_stats; IntegerAttr axis; - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), layer_stats, axis_stats, axis); + quantfork::StatisticsOp stats_op = rewriter.create( + op->getLoc(), op.getInput(), layer_stats, axis_stats, axis); + op.getOutput().replaceAllUsesWith(stats_op.getResult()); return success(); } }; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td index 9d39d89c42ae53..03e7b18569d7a2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td @@ -26,7 +26,7 @@ def GetBatchFunctionOpArgOperands: // because `TF_BatchFunctionOp` doesn't have the `CallOpInterface` trait. def ReplaceBatchFunctionOpToPartitionedCallOp : Pat< (TF_BatchFunctionOp:$src_op_res - $_, $_, $f, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_), + $_, $_, $f, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_), (TF_PartitionedCallOp (GetBatchFunctionOpArgOperands $src_op_res), $f, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index 56b9d7393aacfd..5ed89d89339571 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include @@ -35,6 +36,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" @@ -268,14 +270,21 @@ class AddCustomAggregationOp : public RewritePattern { calib_opts_.calibration_parameters().max_percentile())), }; + int32_t num_bins = GetNumBins(calib_opts_.calibration_method()); + SmallVector output_types{ + value.getType(), + RankedTensorType::get({}, rewriter.getF32Type()), + RankedTensorType::get({}, rewriter.getF32Type()), + RankedTensorType::get({num_bins}, rewriter.getI64Type()), + }; + // Insert custom aggregation op between operand and operator. 
rewriter.setInsertionPointAfterValue(value); Operation *aggregator_op = rewriter.create( - op->getLoc(), value.getType(), value, attributes); + op->getLoc(), output_types, value, attributes); Value aggregator_op_result = aggregator_op->getOpResult(0); - value.replaceAllUsesWith(aggregator_op_result); - aggregator_op->replaceUsesOfWith(aggregator_op_result, value); + value.replaceAllUsesExcept(aggregator_op_result, aggregator_op); } return success(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td index 559c5e31a71f09..e33e226be35515 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td @@ -65,7 +65,10 @@ def TF_CustomAggregatorOp : TF_Op<"CustomAggregator", [Pure]> { ); let results = (outs - TensorOf<[TF_Float32]>:$output + TensorOf<[TF_Float32]>:$output, + TensorOf<[TF_Float32]>:$min, + TensorOf<[TF_Float32]>:$max, + TensorOf<[TF_Int64]>:$histogram ); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index bd53b29ad79255..78a8321f9f87d4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -51,10 +51,11 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_saver_op", # Required for CalibrationStatisticsSaver op registration. "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op", # Required for CustomAggregator op registration. 
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:convert_asset_args", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", - "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:dump_tensor_op", + "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:dump_tensor_op", # Required for DumpTensor op registration. "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", @@ -108,7 +109,6 @@ pytype_strict_library( "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_algorithm", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_py", - "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:pywrap_calibration", "//tensorflow/core:protos_all_py", "//tensorflow/python/client:session", "//tensorflow/python/eager:context", @@ -120,7 +120,6 @@ pytype_strict_library( "//tensorflow/python/saved_model:loader", "//tensorflow/python/trackable:autotrackable", "//tensorflow/python/types:core", - "//third_party/py/numpy", "@absl_py//absl/logging", ], ) @@ -129,13 +128,7 @@ tf_py_strict_test( name = "py_function_lib_py_test", srcs = ["py_function_lib_test.py"], main = "py_function_lib_test.py", - deps = [ - ":py_function_lib_py", - ":pywrap_function_lib", - "//tensorflow:tensorflow_py", - "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_py", - "//tensorflow/python/platform:client_testlib", - ], + deps = ["//tensorflow/python/platform:client_testlib"], ) cc_library( @@ -220,6 +213,7 @@ tf_python_pybind_extension( "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@pybind11", ], @@ 
-276,13 +270,11 @@ pytype_strict_library( "//tensorflow/python/framework:importer", "//tensorflow/python/framework:ops", "//tensorflow/python/lib/io:file_io", - "//tensorflow/python/ops:variables", "//tensorflow/python/saved_model:builder", "//tensorflow/python/saved_model:constants", "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:tag_constants", "//tensorflow/python/training:saver", - "//tensorflow/python/training:training_lib", "@absl_py//absl/logging", ], ) @@ -421,7 +413,6 @@ tf_py_strict_test( "//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:tag_constants", "//tensorflow/python/trackable:autotrackable", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index ec86deac1b497d..08ff75ac802613 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -6211,25 +6211,25 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + initial_num_bins=32, ), ), stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + initial_num_bins=32, ), ), stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, 
calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + initial_num_bins=32, ), ), stablehlo_quant_config_pb2.CalibrationOptions( calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, calibration_parameters=stablehlo_quant_config_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + initial_num_bins=32, ), ), ], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h index 4be78b1dc74bf6..fbba72479805d6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_PY_FUNCTION_LIB_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_PY_FUNCTION_LIB_H_ +#include #include #include #include @@ -46,10 +47,12 @@ class PyFunctionLibrary { // `add_meta_graph_and_variables` function, which is internally used to add a // `MetaGraphDef` to save to the SavedModel. // + // Returns `true` if successful. Returns `std::nullopt` otherwise. + // // If the function signature changes, likely its corresponding .pyi type // hinting and definition should also change. - // LINT.IfChange - virtual void SaveExportedModel( + // LINT.IfChange(save_exported_model) + virtual std::optional SaveExportedModel( absl::string_view dst_saved_model_path, const ExportedModel& exported_model, absl::string_view src_saved_model_path, @@ -70,18 +73,15 @@ class PyFunctionLibrary { // of type `RepresentativeDatasetOrMapping`, which is used to run the // calibration. // - // Returns the updated exported model where the collected calibration - // statistics are added to `CustomAggregator` nodes at the `min` and `max` - // attributes. + // Returns `true` if successful. 
Returns `std::nullopt` otherwise. // // If the function signature changes, likely its corresponding .pyi type // hinting and definition should also change. // LINT.IfChange(run_calibration) - virtual void RunCalibration( + virtual std::optional RunCalibration( absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const ::stablehlo::quantization::CalibrationOptions& calibration_options, bool force_graph_mode_calibration, const absl::flat_hash_map& representative_dataset_file_map) const = 0; @@ -93,14 +93,16 @@ class PyFunctionLibrary { // Retrieves min and max value from `calibration_statistics`, based on the // calibration method specified by `calibration_options`. // + // Returns `std::nullopt` if unsuccessful. + // // If the function signature changes, likely its corresponding .pyi type // hinting and definition should also change. // LINT.IfChange(get_calibration_min_max_value) - virtual stablehlo::quantization::MinMaxValue GetCalibrationMinMaxValue( - const tensorflow::calibrator::CalibrationStatistics& - calibration_statistics, - const ::stablehlo::quantization::CalibrationOptions& calibration_options) - const = 0; + virtual std::optional + GetCalibrationMinMaxValue(const tensorflow::calibrator::CalibrationStatistics& + calibration_statistics, + const ::stablehlo::quantization::CalibrationOptions& + calibration_options) const = 0; // LINT.ThenChange( // pywrap_function_lib.pyi:get_calibration_min_max_value, // py_function_lib.py:get_calibration_min_max_value, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py index 00f261f4a66c7c..f630138f81fca1 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py @@ -15,7 +15,9 @@ """Defines a wrapper class for overridden python method 
definitions.""" from collections.abc import Callable, Collection, Mapping, Sequence -from typing import Optional +import functools +import traceback +from typing import Optional, TypeVar from absl import logging @@ -45,6 +47,11 @@ _ASSETS_DIR = 'assets' _ASSETS_EXTRA_DIR = 'assets.extra' +# Type variable for a type that is not `None`. This represents a return value of +# methods in `PyFunctionLibrary` that should not be `None`, as `None` represents +# that the execution was unsucessful, transfored as `std::nullopt_t` from c++. +NotNoneT = TypeVar('NotNoneT') + def _get_saver_def_or_none( exported_model: exported_model_pb2.ExportedModel, @@ -502,6 +509,117 @@ def _run_graph_for_calibration( logging.info('Calibration step complete.') +def _run_calibration( + saved_model_path: str, + signature_keys: Sequence[str], + tags: Collection[str], + force_graph_mode_calibration: bool, + representative_dataset_file_map: Mapping[ + str, quantization_options_pb2.RepresentativeDatasetFile + ], +) -> bool: + """Runs calibration and adds calibration statistics to exported model. + + Args: + saved_model_path: Path to the SavedModel to run calibration. + signature_keys: List of signature keys corresponding to SignatureDefs to run + calibration on. + tags: A set of tags that identify the MetaGraphDef. + force_graph_mode_calibration: If True, runs the calibration in graph mode. + representative_dataset_file_map: Signature key -> + `RepresentativeDatasetFile` mapping for running the calibration step. Each + dataset file stores the representative dataset for the function matching + the signature key. + + Returns: + `True` upon successfully running calibration. + """ + repr_dataset_map = rd.TfRecordRepresentativeDatasetLoader( + representative_dataset_file_map + ).load() + + # Uses the representative dataset to collect statistics for calibration. + # After this operation, min & max values are stored separately in a global + # CalibratorSingleton instance. 
+ _run_graph_for_calibration( + saved_model_path, + signature_keys, + tags, + repr_dataset_map, + force_graph_mode_calibration, + ) + + # Dummy value to indicate successful run, as `None` would indicate error. See + # comments in `NotNoneT`. + return True + + +def _call_and_return_none_on_error( + func: Callable[[], NotNoneT], error_msg: str +) -> Optional[NotNoneT]: + """Calls `func` and returns `None` on error. + + This is used to gracefully return the 'error status' represented as `None`, as + raising exceptions from `PyFunctionLibrary` methods crashes the program. + + Args: + func: The function to run. The function should be a callable returning a + non-None value. + error_msg: The error message to log upon error. Used for debugging purposes. + + Returns: + `None` if the function raises an exception. The return value of `func` + otherwise. + """ + try: + return func() + except Exception as ex: # pylint: disable=broad-exception-caught; Required for graceful failing with pybind11. + # Prints the exception traceback for debuggability. + traceback.print_exception(ex) + # Additional error log for debuggability. + logging.error(error_msg) + return None + + +def _save_model_and_copy_assets( + exported_model: exported_model_pb2.ExportedModel, + src_saved_model_path: str, + dst_saved_model_path: str, + signature_def_map: Mapping[str, meta_graph_pb2.SignatureDef], + tags: Collection[str], +) -> bool: + """Saves the model and copies the assets from the source model. + + Args: + exported_model: ExportedModel to save. + src_saved_model_path: Path to the source SavedModel. This will be used to + copy the asset files to `dst_saved_model_path`. + dst_saved_model_path: Destination path to save the exported model. + signature_def_map: Signature key -> SignatureDef mapping. + tags: Tags to attach to the saved MetaGraphDef. + + Returns: + `True` upon successfully saving the model. 
+ """ + save_model.save_model_v1( + exported_model.graph_def, + dst_saved_model_path, + signature_def_map, + tags, + init_op_name=exported_model.init_node_name, + saver_def=_get_saver_def_or_none(exported_model), + checkpoint_dir=exported_model.checkpoint_dir, + function_aliases=exported_model.function_aliases, + asset_file_defs=exported_model.asset_file_defs, + ) + + _copy_assets(src_saved_model_path, dst_saved_model_path) + + # Dummy value to indicate successful run, as `None` would indicate error. See + # comments in `NotNoneT`. + return True + + class PyFunctionLibrary(pywrap_function_lib.PyFunctionLibrary): """Wrapper class for overridden python method definitions. @@ -517,7 +635,7 @@ def save_exported_model( src_saved_model_path: str, tags: set[str], serialized_signature_def_map: dict[str, bytes], - ) -> None: + ) -> Optional[bool]: # LINT.ThenChange(py_function_lib.h:save_exported_model) """Saves `ExportedModel` to `dst_saved_model_path` as a SavedModel. @@ -528,6 +646,10 @@ def save_exported_model( copy the asset files to `dst_saved_model_path`. tags: Tags to attach to the saved MetaGraphDef. serialized_signature_def_map: Signature key -> serialized SignatureDef. + + Returns: + `True` upon successful execution. `None` when an error is raised + internally. 
""" exported_model = exported_model_pb2.ExportedModel.FromString( exported_model_serialized @@ -540,20 +662,21 @@ def save_exported_model( serialized_signature_def ) - save_model.save_model_v1( - exported_model.graph_def, - dst_saved_model_path, - signature_def_map, - tags, - init_op_name=exported_model.init_node_name, - saver_def=_get_saver_def_or_none(exported_model), - checkpoint_dir=exported_model.checkpoint_dir, - function_aliases=exported_model.function_aliases, - asset_file_defs=exported_model.asset_file_defs, + return _call_and_return_none_on_error( + func=functools.partial( + _save_model_and_copy_assets, + exported_model, + src_saved_model_path, + dst_saved_model_path, + signature_def_map, + tags, + ), + error_msg=( + f'Failed to save model "{dst_saved_model_path}",' + f' signature_def_map: {signature_def_map}, tags: {tags}.' + ), ) - _copy_assets(src_saved_model_path, dst_saved_model_path) - # TODO: b/311097139 - Extract calibration related functions into a separate # file. # LINT.IfChange(run_calibration) @@ -562,10 +685,9 @@ def run_calibration( saved_model_path: str, signature_keys: list[str], tags: set[str], - calibration_options_serialized: bytes, force_graph_mode_calibration: bool, representative_dataset_file_map_serialized: dict[str, bytes], - ) -> None: + ) -> Optional[bool]: # LINT.ThenChange(py_function_lib.h:run_calibration) """Runs calibration and adds calibration statistics to exported model. @@ -574,7 +696,6 @@ def run_calibration( signature_keys: List of signature keys corresponding to SignatureDefs to run calibration on. tags: A set of tags that identify the MetaGraphDef. - calibration_options_serialized: Serialized `CalibrationOptions`. force_graph_mode_calibration: If True, runs the calibration in graph mode. representative_dataset_file_map_serialized: Signature key -> `RepresentativeDatasetFile` mapping for running the calibration step. @@ -582,10 +703,9 @@ def run_calibration( matching the signature key. 
Returns: - Updated exported model (serialized) where the collected calibration - statistics are added to `CustomerAggregator` nodes at the `min` and `max` - attributes. + The error message if the function raises and exception. `None` otherwise. """ + # Deserialize `RepresentativeDatasetFile` values. dataset_file_map = {} for ( signature_key, @@ -597,19 +717,19 @@ def run_calibration( ) ) - repr_dataset_map = rd.TfRecordRepresentativeDatasetLoader( - dataset_file_map=dataset_file_map - ).load() - - # Uses the representative dataset to collect statistics for calibration. - # After this operation, min & max values are stored separately in a global - # CalibratorSingleton instance. - _run_graph_for_calibration( - saved_model_path, - signature_keys, - tags, - repr_dataset_map, - force_graph_mode_calibration, + return _call_and_return_none_on_error( + func=functools.partial( + _run_calibration, + saved_model_path, + signature_keys, + tags, + force_graph_mode_calibration, + dataset_file_map, + ), + error_msg=( + f'Failed to run calibration on model "{saved_model_path}",' + f' signature_keys: {signature_keys}, tags: {tags}.' + ), ) # LINT.IfChange(get_calibration_min_max_value) @@ -617,7 +737,7 @@ def get_calibration_min_max_value( self, calibration_statistics_serialized: bytes, calibration_options_serialized: bytes, - ) -> tuple[float, float]: + ) -> Optional[tuple[float, float]]: """Calculates min and max values from statistics. Args: @@ -627,17 +747,26 @@ def get_calibration_min_max_value( how the min / max should be calculated. Returns: - (min_value, max_value): Min and max calculated using calib_opts. - - Raises: - ValueError: Unsupported calibration method is given. + (min_value, max_value): Min and max calculated using calib_opts. `None` + upon error. 
""" # LINT.ThenChange(py_function_lib.h:get_calibration_min_max_value) - return calibration_algorithm.get_min_max_value( - calibration_statistics_pb2.CalibrationStatistics.FromString( - calibration_statistics_serialized + + # Deserialize values passed from c++. + statistics = calibration_statistics_pb2.CalibrationStatistics.FromString( + calibration_statistics_serialized + ) + options = stablehlo_quant_config_pb2.CalibrationOptions.FromString( + calibration_options_serialized + ) + + return _call_and_return_none_on_error( + functools.partial( + calibration_algorithm.get_min_max_value, + statistics, + options, ), - stablehlo_quant_config_pb2.CalibrationOptions.FromString( - calibration_options_serialized + error_msg=( + f'Retrieving calibrated min / max failed. Options: {options}.' ), ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc index 9e8ecd2352ae94..fc181edb8a75f5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 @@ -53,35 +54,36 @@ class PyFunctionLibraryTrampoline : public PyFunctionLibrary { public: using PyFunctionLibrary::PyFunctionLibrary; - void SaveExportedModel(const absl::string_view dst_saved_model_path, - const ExportedModel& exported_model, - const absl::string_view src_saved_model_path, - const std::unordered_set& tags, - const absl::flat_hash_map& - signature_def_map) const override { - PYBIND11_OVERRIDE_PURE(void, PyFunctionLibrary, save_exported_model, - dst_saved_model_path, exported_model, - src_saved_model_path, tags, signature_def_map); + std::optional SaveExportedModel( + const absl::string_view dst_saved_model_path, + const ExportedModel& exported_model, + const absl::string_view src_saved_model_path, + const std::unordered_set& tags, + const absl::flat_hash_map& signature_def_map) + const override { + PYBIND11_OVERRIDE_PURE(std::optional, PyFunctionLibrary, + save_exported_model, dst_saved_model_path, + exported_model, src_saved_model_path, tags, + signature_def_map); } - void RunCalibration( + std::optional RunCalibration( const absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const CalibrationOptions& calibration_options, const bool force_graph_mode_calibration, const absl::flat_hash_map& representative_dataset_file_map) const override { - PYBIND11_OVERRIDE_PURE(void, PyFunctionLibrary, run_calibration, - saved_model_path, signature_keys, tags, - calibration_options, force_graph_mode_calibration, + PYBIND11_OVERRIDE_PURE(std::optional, PyFunctionLibrary, + run_calibration, saved_model_path, signature_keys, + tags, force_graph_mode_calibration, representative_dataset_file_map); } - MinMaxValue GetCalibrationMinMaxValue( + std::optional GetCalibrationMinMaxValue( const 
CalibrationStatistics& calibration_statistics, const CalibrationOptions& calibration_options) const override { - PYBIND11_OVERRIDE_PURE(MinMaxValue, PyFunctionLibrary, + PYBIND11_OVERRIDE_PURE(std::optional, PyFunctionLibrary, get_calibration_min_max_value, calibration_statistics, calibration_options); } @@ -100,8 +102,7 @@ PYBIND11_MODULE(pywrap_function_lib, m) { py::arg("serialized_signature_def_map")) .def("run_calibration", &PyFunctionLibrary::RunCalibration, py::arg("saved_model_path"), py::arg("signature_keys"), - py::arg("tags"), py::arg("calibration_options_serialized"), - py::arg("force_graph_mode_calibration"), + py::arg("tags"), py::arg("force_graph_mode_calibration"), py::arg("representative_dataset_file_map_serialized")) .def("get_calibration_min_max_value", &PyFunctionLibrary::GetCalibrationMinMaxValue, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi index d8c9ed4d9be79e..8e4a7cee6203c7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from typing import Any +from typing import Any, Optional class PyFunctionLibrary: @@ -24,7 +24,7 @@ class PyFunctionLibrary: src_saved_model_path: str, tags: set[str], serialized_signature_def_map: dict[str, bytes], - ) -> None: ... + ) -> Optional[bool]: ... # LINT.ThenChange() # LINT.IfChange(run_calibration) @@ -33,11 +33,10 @@ class PyFunctionLibrary: saved_model_path: str, signature_keys: list[str], tags: set[str], - calibration_options_serialized: bytes, force_graph_mode_calibration: bool, # Value type: RepresentativeDatasetFile. 
representative_dataset_file_map_serialized: dict[str, bytes], - ) -> None: ... + ) -> Optional[bool]: ... # LINT.ThenChange() # LINT.IfChange(get_calibration_min_max_value) @@ -45,5 +44,5 @@ class PyFunctionLibrary: self, calibration_statistics_serialized: bytes, calibration_options_serialized: bytes, - ) -> tuple[float, float]: ... + ) -> Optional[tuple[float, float]]: ... # LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index a0865c44664290..6f5db2c5a823e8 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -223,8 +223,8 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { if (!exported_model.ok()) return exported_model.status(); // Remove the `tpu` tag from the debug quantized saved model as it is - // for CPU. Note the 'tpu' value should be the same as `TPU` defined in - // tensorflow/python/saved_model/tag_constants.py. + // for CPU. Note the 'tpu' value should be the same as `TPU` defined + // in tensorflow/python/saved_model/tag_constants.py. 
if (quantization_options.has_debugger_config()) { tags.erase("tpu"); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 89467d30944ca9..9f4621360e2e89 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -437,7 +437,6 @@ absl::StatusOr QuantizeStaticRangePtq( py_function_library.RunCalibration( *precalibrated_saved_model_dir, signature_keys, tags, - quantization_options.calibration_options(), quantization_options.force_graph_mode_calibration(), representative_dataset_file_map_serialized); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc index b957ffe469a004..c3f5c32bdd9720 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc @@ -47,7 +47,7 @@ absl::Status UnfreezeConstantsAndSaveVariables( }, ctx, module_op)); - if (const tsl::Status create_dir_status = + if (const absl::Status create_dir_status = Env::Default()->CreateDir(std::string(checkpoint_dir)); !create_dir_status.ok()) { LOG(ERROR) << "Failed to create checkpoint directory at: " diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_custom_aggregation_op_to_quant_stats.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_custom_aggregation_op_to_quant_stats.mlir index 02a348b8c8fe3e..f72c9f3388c071 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_custom_aggregation_op_to_quant_stats.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_custom_aggregation_op_to_quant_stats.mlir @@ -1,19 +1,19 @@ // RUN: tf-quant-opt %s 
-quant-convert-tf-custom-aggregator-op-to-quant-stats | FileCheck %s func.func @customAggregator(%arg0: tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>) { - %0 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, id = "0"} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - %1 = "tf.CustomAggregator"(%arg0) {id = "1"} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - func.return %0, %1 : tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32> + %0:4 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, id = "0"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %1:4 = "tf.CustomAggregator"(%arg0) {id = "1"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + func.return %0#0, %1#0 : tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32> } // CHECK: func @customAggregator // CHECK-NEXT: %[[stats:.*]] = "quantfork.stats"(%arg0) {layerStats = dense<[-1.000000e-01, 2.000000e-01]> : tensor<2xf32>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[stats]], %arg0 func.func @doNotHandleNoMinMaxCases(%arg0: tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>) { - %0 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, id = "1"} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - %1 = "tf.CustomAggregator"(%arg0) {max = 0.2 : f32, id = "2"} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - %2 = "tf.CustomAggregator"(%arg0) {id = "3"} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - func.return %0, %1, %2 : tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32> + %0:4 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, id = "1"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %1:4 = "tf.CustomAggregator"(%arg0) {max = 0.2 : f32, id = "2"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %2:4 = "tf.CustomAggregator"(%arg0) {id = "3"} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, 
tensor, tensor, tensor<*xi64>) + func.return %0#0, %1#0, %2#0 : tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32> } // CHECK: func @doNotHandleNoMinMaxCases // CHECK-NOT: "quantfork.stats" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir index b8ed5d5f361d36..052da55dce336d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir @@ -26,10 +26,10 @@ module { // CalibrationOptions(calibration_method=CALIBRATION_METHOD_MIN_MAX) // MIN-MAX-CHECK: func @wrap_composite_func -// MIN-MAX-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// MIN-MAX-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// MIN-MAX-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// MIN-MAX-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) // MIN-MAX-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// MIN-MAX-CHECK-NEXT: [[res:%.*]] = 
"tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// MIN-MAX-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) // MIN-MAX-CHECK-NEXT: return [[res]] : tensor<*xf32> // MIN-MAX-CHECK: func @no_composite_func @@ -43,10 +43,10 @@ module { // CalibrationOptions(calibration_method=CALIBRATION_METHOD_AVERAGE_MIN_MAX) // AVERAGE-MIN-MAX-CHECK: func @wrap_composite_func -// AVERAGE-MIN-MAX-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// AVERAGE-MIN-MAX-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// AVERAGE-MIN-MAX-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// AVERAGE-MIN-MAX-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) // AVERAGE-MIN-MAX-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// 
AVERAGE-MIN-MAX-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// AVERAGE-MIN-MAX-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) // AVERAGE-MIN-MAX-CHECK-NEXT: return [[res]] : tensor<*xf32> // AVERAGE-MIN-MAX-CHECK: func @no_composite_func @@ -63,10 +63,10 @@ module { // calibration_parameters=CalibrationParameters(initial_num_bins=256, min_percentile=0.001, max_percentile=99.999) // ) // HISTOGRAM-PERCENTILE-CHECK: func @wrap_composite_func -// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) 
-> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-PERCENTILE-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-PERCENTILE-CHECK-NEXT: return [[res]] : tensor<*xf32> // HISTOGRAM-PERCENTILE-CHECK: func @no_composite_func @@ -83,10 +83,10 @@ module { // calibration_parameters=CalibrationParameters(initial_num_bins=256) // ) // HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @wrap_composite_func -// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> 
{calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: return [[res]] : tensor<*xf32> // HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @no_composite_func @@ -103,10 +103,10 @@ module { // calibration_parameters=CalibrationParameters(initial_num_bins=256) // ) // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @wrap_composite_func -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : 
(tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: return [[res]] : tensor<*xf32> // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @no_composite_func @@ -123,10 +123,10 @@ module { // calibration_parameters=CalibrationParameters(initial_num_bins=256) // ) // HISTOGRAM-MSE-SYMMETRIC-CHECK: func @wrap_composite_func -// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, 
{{.*}} = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) -// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: return [[res]] : tensor<*xf32> // HISTOGRAM-MSE-SYMMETRIC-CHECK: func @no_composite_func @@ -174,4 +174,4 @@ module { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<100352x10xf32>) -> tensor return %0 : tensor } -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir index 4aa1ae76b8a83d..6a1621cdf17e89 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir +++ 
b/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir @@ -1,17 +1,17 @@ // RUN: tf-quant-opt %s -quant-issues-ids-of-custom-aggregation-ops | FileCheck %s func.func @issue_ids(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.CustomAggregator"(%arg1) {id = ""} : (tensor<*xf32>) -> tensor<*xf32> - %1 = "tf.CustomAggregator"(%arg0) {id = ""} : (tensor<*xf32>) -> tensor<*xf32> + %0:4 = "tf.CustomAggregator"(%arg1) {id = ""} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<*xi64>) + %1:4 = "tf.CustomAggregator"(%arg0) {id = ""} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<*xi64>) %2 = "tf.AddV2"(%1, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %3 = "tf.CustomAggregator"(%2) {id = ""} : (tensor<*xf32>) -> tensor<*xf32> + %3:4 = "tf.CustomAggregator"(%2) {id = ""} : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<*xi64>) func.return %3 : tensor<*xf32> } // CHECK: func @issue_ids -// CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = "0"}> : (tensor<*xf32>) -> tensor<*xf32> -// CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = "1"}> : (tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{id = "0"}> : (tensor<*xf32>) +// CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{id = "1"}> : (tensor<*xf32>) // CHECK-NEXT: [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = "2"}> : (tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{id = "2"}> : (tensor<*xf32>) // CHECK-NEXT: return [[res]] : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index 5d5342e8a264c4..28e2f104221284 100644 --- 
a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -1,4 +1,4 @@ -load("@local_tsl//tsl:tsl.default.bzl", "tsl_pybind_extension") +load("@local_xla//xla/tsl:tsl.default.bzl", "tsl_pybind_extension") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("//tensorflow:strict.default.bzl", "py_strict_test") diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td index 83eac78b7574d6..0c783c01caa287 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td @@ -109,7 +109,8 @@ def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", [Pure]> { configuration specified in `VariableDeviceShardingConfigProto`. This op returns a scalar string tensor containing the loaded variable name, which can be - used as a key to look for the loaded IFRT array in runtime. + used as a key to look for the loaded IFRT array in runtime and a restored tensor, which + maybe lowered to a future by runtime. 
}]; let arguments = (ins @@ -119,7 +120,8 @@ def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", [Pure]> { ); let results = (outs - TF_StrTensor:$array_key + TF_StrTensor:$array_key, + TF_Tensor: $tensor_future ); TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 5e0e58c279e358..6ba660297366ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1136,6 +1136,7 @@ to be batched.}]>:$captured_tensors, DefaultValuedOptionalAttr:$low_priority_batch_timeout_micros, DefaultValuedOptionalAttr:$low_priority_allowed_batch_sizes, DefaultValuedOptionalAttr:$low_priority_max_enqueued_batches, + DefaultValuedOptionalAttr, "\"low_priority_padding_with_max_batch_size\"">:$mixed_priority_policy, DefaultValuedOptionalAttr:$enable_large_batch_splitting ); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/order_by_dialect.mlir b/tensorflow/compiler/mlir/tensorflow/tests/order_by_dialect.mlir index bb3f74d702b261..f19de0b5996999 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/order_by_dialect.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/order_by_dialect.mlir @@ -85,12 +85,12 @@ func.func @tf_and_mhlo(%arg0: tensor<32x28x28x1xf32>, %arg1: tensor>>) -> tensor<3x3x1x5xf32> %5 = "tf.ReadVariableOp"(%arg3) : (tensor>>) -> tensor<3920x10xf32> %6 = mhlo.convolution(%arg0, %4) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<32x28x28x1xf32>, tensor<3x3x1x5xf32>) -> tensor<32x28x28x5xf32> - %7 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<32x28x28x5xf32> + %7 = "mhlo.broadcast_in_dim"(%3) <{broadcast_dimensions = dense<3> 
: tensor<1xi64>}> : (tensor<5xf32>) -> tensor<32x28x28x5xf32> %8 = mhlo.add %6, %7 : tensor<32x28x28x5xf32> %9 = mhlo.maximum %8, %1 : tensor<32x28x28x5xf32> %10 = "mhlo.reshape"(%9) : (tensor<32x28x28x5xf32>) -> tensor<32x3920xf32> %11 = "mhlo.dot"(%10, %5) : (tensor<32x3920xf32>, tensor<3920x10xf32>) -> tensor<32x10xf32> - %12 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>) -> tensor<32x10xf32> + %12 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<10xf32>) -> tensor<32x10xf32> %13 = mhlo.add %11, %12 : tensor<32x10xf32> %14 = mhlo.maximum %13, %0 : tensor<32x10xf32> return %14 : tensor<32x10xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_outline_entry_functions.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_outline_entry_functions.mlir deleted file mode 100644 index 60f767a04cbf58..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/xla_outline_entry_functions.mlir +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-xla-outline-entry-functions | FileCheck %s - -// Check that we outline the top-level functions. 
- -// CHECK-LABEL: func.func private @main_outlined(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true} { -// CHECK: %0 = "tf.StatefulPartitionedCall"(%arg0) <{config = "", config_proto = "", executor_type = "", f = @f}> {_xla_compile_device_type = "CPU"} : (tensor) -> tensor -// CHECK: %cst = "tf.Const"() <{value = dense<5> : tensor}> : () -> tensor -// CHECK: %1 = "tf.Add"(%0, %cst) : (tensor, tensor) -> tensor -// CHECK: return %1 : tensor -// CHECK: } - -// CHECK: func.func @main(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", tf.entry_function = {}} { -// CHECK: %0 = "tf.StatefulPartitionedCall"(%arg0) <{config = "", config_proto = "", executor_type = "", f = @main_outlined}> {_xla_compile_device_type = "CPU", allow_soft_placement = true} : (tensor) -> tensor -// CHECK: return %0 : tensor -// CHECK: } -func.func @main(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true, tf.entry_function = {}} { - %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", _xla_compile_device_type = "CPU", f = @f} : (tensor) -> (tensor) - %1 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor - %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> (tensor) - func.return %2 : tensor -} - -func.func @f(%arg0: tensor) -> tensor { - func.return %arg0 : tensor -} - -// ----- - -// Tests multiple entry functions. 
- -// CHECK-LABEL: func.func private @entry1_outlined(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true} { -// CHECK: %0 = "tf.StatefulPartitionedCall"(%arg0) <{config = "", config_proto = "", executor_type = "", f = @f1}> {_xla_compile_device_type = "CPU"} : (tensor) -> tensor -// CHECK: %cst = "tf.Const"() <{value = dense<5> : tensor}> : () -> tensor -// CHECK: %1 = "tf.Add"(%0, %cst) : (tensor, tensor) -> tensor -// CHECK: return %1 : tensor -// CHECK: } - -// CHECK-LABEL: func.func private @entry2_outlined(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true} { -// CHECK: %0 = "tf.StatefulPartitionedCall"(%arg0) <{config = "", config_proto = "", executor_type = "", f = @f1}> {_xla_compile_device_type = "CPU"} : (tensor) -> tensor -// CHECK: %cst = "tf.Const"() <{value = dense<5> : tensor}> : () -> tensor -// CHECK: %1 = "tf.Add"(%0, %cst) : (tensor, tensor) -> tensor -// CHECK: return %1 : tensor -// CHECK: } - -// CHECK: func.func @entry1(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", tf.entry_function = {}} { -// CHECK: %0 = "tf.StatefulPartitionedCall"(%arg0) <{config = "", config_proto = "", executor_type = "", f = @entry1_outlined}> {_xla_compile_device_type = "CPU", allow_soft_placement = true} : (tensor) -> tensor -// CHECK: return %0 : tensor -// CHECK: } - -// CHECK: func.func @entry2(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", tf.entry_function = {}} { -// CHECK: %0 = "tf.StatefulPartitionedCall"(%arg0) <{config = "", config_proto = "", executor_type = "", f = @entry2_outlined}> {_xla_compile_device_type = "CPU", allow_soft_placement = true} : (tensor) -> tensor -// CHECK: return %0 : tensor -// CHECK: } -func.func @entry1(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true, tf.entry_function = {}} { - %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", 
config_proto = "", executor_type = "", _xla_compile_device_type = "CPU", f = @f1} : (tensor) -> (tensor) - %1 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor - %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> (tensor) - func.return %2 : tensor -} - -func.func @entry2(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true, tf.entry_function = {}} { - %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", _xla_compile_device_type = "CPU", f = @f1} : (tensor) -> (tensor) - %1 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor - %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> (tensor) - func.return %2 : tensor -} - -func.func @f1(%arg0: tensor) -> tensor { - func.return %arg0 : tensor -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir deleted file mode 100644 index e36bdaa72e41b8..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir +++ /dev/null @@ -1,110 +0,0 @@ -// RUN: tf-opt %s -split-input-file -tf-xla-rewrite-v2 | FileCheck %s - - -module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:GPU:0"]} { - // CHECK-LABEL: func.func @convert_cluster_func - func.func @convert_cluster_func(%arg0: tensor) -> tensor { - // CHECK: "tf_device.launch"() - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: "tf._XlaCompile"(%arg0) <{function = @func, must_compile = true, operandSegmentSizes = array}> : (tensor) -> (tensor<3x!tf_type.string>, tensor) - // CHECK: "tf_device.launch"() - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: "tf._XlaRun"(%arg0, %0#0) : (tensor, tensor<3x!tf_type.string>) -> tensor - %0 = "tf_device.cluster_func"(%arg0) {func = @func, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor) -> tensor - 
func.return %0 : tensor - } - - func.func @func(%arg0: tensor) -> tensor { - func.return %arg0 : tensor - } -} - -// ----- - -module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:GPU:0"]} { - // CHECK-LABEL: func.func @convert_cluster_func_with_resources_in_order - func.func @convert_cluster_func_with_resources_in_order(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "tf_device.launch"() - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: "tf._XlaCompile"(%arg1, %arg0) <{function = @func_with_resources_in_order, must_compile = true, operandSegmentSizes = array}> : (tensor, tensor) - // CHECK: "tf_device.launch"() - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: "tf._XlaRun"(%arg1, %arg0, %0#0) : (tensor, tensor, tensor<3x!tf_type.string>) -> tensor - %0 = "tf_device.cluster_func"(%arg1, %arg0) {func = @func_with_resources_in_order, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor, tensor) -> (tensor) - func.return %0 : tensor - } - - func.func @func_with_resources_in_order(%arg0 : tensor, %arg1 : tensor) -> tensor { - func.return %arg0 : tensor - } -} - -// ----- - -module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:GPU:0"]} { - // CHECK-LABEL: func.func @convert_cluster_func_with_resources - func.func @convert_cluster_func_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "tf_device.launch"() - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: "tf._XlaCompile"(%arg1, %arg0) <{function = @func_with_resources_1, must_compile = true, operandSegmentSizes = array}> : (tensor, tensor) -> (tensor<3x!tf_type.string>, tensor) - // CHECK: "tf_device.launch"() - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: "tf._XlaRun"(%arg1, %arg0, %0#0) : (tensor, 
tensor, tensor<3x!tf_type.string>) -> tensor - %0 = "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_resources_1, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor, tensor) -> tensor - // CHECK: "tf_device.launch"() - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: "tf._XlaCompile"(%arg1, %arg0) <{function = @func_with_resources_2, must_compile = true, operandSegmentSizes = array}> : (tensor, tensor) -> (tensor<3x!tf_type.string>, tensor) - // CHECK: "tf_device.launch"() - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: "tf._XlaRun"(%arg1, %arg0, %2#0) : (tensor, tensor, tensor<3x!tf_type.string>) -> tensor - %1 = "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_resources_2, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor, tensor) -> tensor - return %0 : tensor - } - - - func.func @func_with_resources_1(%arg0 : tensor, %arg1: tensor) -> tensor { - func.return %arg1 : tensor - } - - func.func @func_with_resources_2(%arg0 : tensor, %arg1: tensor) -> tensor { - func.return %arg1 : tensor - } -} - -// ----- - -// CHECK-LABEL: func.func @outside_compilation_in_generic_pipeline -module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0"], tf.versions = {producer = 888 : i32}} { - func.func @outside_compilation_in_generic_pipeline(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK: tf_device.launch - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: "tf._XlaCompile"() <{function = @func, must_compile = true, operandSegmentSizes = array}> - // CHECK: tf_device.parallel_execute - // CHECK: tf_device.launch - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> - // CHECK: tf.B - // CHECK: tf._XlaSendFromHost - // CHECK: tf_device.launch - // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}> - // CHECK: tf._XlaRun - %0 = 
"tf_device.parallel_execute"() ({ - "tf_device.launch"() ({ - %1 = "tf._XlaCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf_type.string> - %2 = "tf.B"() : () -> tensor<2xi32> - "tf._XlaSendFromHost"(%2, %1) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_0_retvals"} : (tensor<2xi32>, tensor<3x!tf_type.string>) -> () - tf_device.return - }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () - tf_device.return - }, { - %0 = "tf_device.cluster_func"() {func = @func, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : () -> tensor<2xi32> - tf_device.return %0 : tensor<2xi32> - }) : () -> tensor<2xi32> - return %0 : tensor<2xi32> - } - func.func @func() -> tensor<2xi32> { - %2 = "tf.A"() : () -> tensor<2xi32> - %3 = "tf._XlaHostComputeMlir"() {host_mlir_module = "", manual_sharding = false, recv_key = "host_compute_channel_0_retvals", send_key = "host_compute_channel_0_args"} : () -> tensor<2xi32> - %4 = "tf.C"(%3) : (tensor<2xi32>) -> tensor<2xi32> - func.return %4 : tensor<2xi32> - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 3d1cf1bd58fa38..2e090224a5c86c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -541,7 +541,6 @@ cc_library( "xla_call_module_serialization.cc", "xla_inline_device_ops.cc", "xla_rewrite.cc", - "xla_rewrite_v2.cc", "xla_validate_inputs.cc", ], hdrs = [ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 84220aa346bf50..880dfa837e881c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -122,10 +122,6 @@ LogicalResult ConstantFoldFallbackHook( inputs.push_back(input.cast()); } - // Avoid overlapping folds with the same 
context. - // TODO(jpienaar): Avoid using global context & mutex here. - static auto* mu = new tensorflow::mutex(); - tensorflow::mutex_lock l(*mu); SmallVector constants; LogicalResult status = EvaluateOperation(inst, inputs, constants); results.assign(constants.begin(), constants.end()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc index a239c7304a0ae0..c61b1e0c14a852 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc @@ -94,12 +94,7 @@ void AddTPULowerClusterToRuntimeOpsPassPipeline(OpPassManager& pm, void AddNonTPULowerClusterToRuntimeOpsPassPipeline( OpPassManager& pm, llvm::StringRef module_name) { // Rewrite cluster functions into XLA launch ops. - if (tensorflow::GetMlirCommonFlags() - ->tf_mlir_enable_generic_outside_compilation) { - pm.addPass(mlir::TFDevice::CreateXlaRewriteV2Pass()); - } else { - pm.addPass(mlir::TFDevice::CreateXlaRewritePass()); - } + pm.addPass(mlir::TFDevice::CreateXlaRewritePass()); // Re-run the canonicalizer pass as some cleanup during resource op lifting // pass opens up some opportunities for canonicalization of cluster ops. // Specifically, we want to eliminate pass through results from the cluster diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index da89e77cb0862c..3a2ba6f181f649 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -483,10 +483,6 @@ std::unique_ptr> CreateXlaInlineDeviceOpsPass(); // type` with `tf.XlaLaunch` ops. 
std::unique_ptr> CreateXlaRewritePass(); -// Creates a pass that rewrites partitioned calls with `tf._XlaCompile` op and -// `tf.XlaRun` op. -std::unique_ptr> CreateXlaRewriteV2Pass(); - // Create a pass that validates the input graph to the CPU/GPU bridge. std::unique_ptr> CreateXlaValidateInputsPass(); } // namespace TFDevice diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc index 0c450126e4e090..e565d50660558c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc @@ -193,14 +193,15 @@ struct EmbeddingPipeliningPass bool UseEmbeddingPipelining(ModuleOp& module) { // Enable automated pipelining pass unless: - // 1. The user disables it via flog, or + // 1. The user disables it via flag, or // 2. The graph contains TF.Summary ops. Graphs like this typically only run // for a single step which doesn't work in pipelining. if (tensorflow::GetBuildXlaOpsPassFlags() - ->tf_xla_disable_full_embedding_pipelining) + ->tf_xla_disable_full_embedding_pipelining) { + LOG(INFO) << "Embedding pipelining disabled via flag."; return false; - + } // Detect summaries by looking for key Ops in the graph. It would be better to // do this via operator attributes rather than looking for a specific op. 
WalkResult walk_result = module.walk([&](Operation* op) -> WalkResult { @@ -208,10 +209,10 @@ bool UseEmbeddingPipelining(ModuleOp& module) { return WalkResult::advance(); }); if (walk_result.wasInterrupted()) { - VLOG(1) << "TF summaries detected - disabling embedding pipelining."; + LOG(INFO) << "TF summaries detected - disabling embedding pipelining."; return false; } - VLOG(1) << "Embedding pipelining rewrite enabled."; + LOG(INFO) << "Embedding pipelining rewrite enabled."; return true; } @@ -1685,12 +1686,11 @@ Operation* LiftNonTpuFuncCaller(mlir::OpBuilder& builder, } void EmbeddingPipeliningPass::runOnOperation() { - VLOG(3) << "EmbeddingPipeliningPass::runOnOperation()"; + LOG(INFO) << "EmbeddingPipeliningPass::runOnOperation()"; ModuleOp module = getOperation(); // We only use one of the EmbeddingPipelining and EmbeddingSequencing passes. if (!UseEmbeddingPipelining(module)) return; - VLOG(1) << "Embedding pipelining rewrite enabled."; SymbolTable symbol_table(module); @@ -1722,7 +1722,7 @@ void EmbeddingPipeliningPass::runOnOperation() { // If there are no forward pass ops, there is no SC, so we end early. if (forward_pass_ops.empty()) { if (backward_pass_ops.empty()) { - VLOG(1) << "no pipelining ops found"; + LOG(INFO) << "no pipelining ops found"; return; } else { (*backward_pass_ops.begin())->emitOpError() @@ -1812,11 +1812,11 @@ void EmbeddingPipeliningPass::runOnOperation() { if (failed(result)) return signalPassFailure(); merged_set.insert(non_tpu_ops.begin(), non_tpu_ops.end()); - VLOG(3) << "Forwards pass " << forward_pass_ops.size() - << " ops, backwards pass " << backward_pass_ops.size() - << " ops, core " << core_tpu_ops.size() - << " ops. Total = " << merged_set.size() << " of " - << GetNumOps(loop_body_func); + LOG(INFO) << "Forwards pass " << forward_pass_ops.size() + << " ops, backwards pass " << backward_pass_ops.size() + << " ops, core " << core_tpu_ops.size() + << " ops. 
Total = " << merged_set.size() << " of " + << GetNumOps(loop_body_func); builder.setInsertionPointAfter(*non_tpu_ops.begin()); TF::StatefulPartitionedCallOp non_tpu_caller = nullptr; @@ -2185,7 +2185,8 @@ void EmbeddingPipeliningPass::runOnOperation() { int parallel_iterations = parallel_iterations_flag > 0 ? parallel_iterations_flag : orig_while_op.getParallelIterations(); - VLOG(1) << "Setting parallel_iterations_flag to " << parallel_iterations_flag; + LOG(INFO) << "Setting parallel_iterations_flag to " + << parallel_iterations_flag; auto new_while_op = builder.create( orig_while_op->getLoc(), new_body_return_types, new_while_operands.getArrayRef(), cond.getSymName(), body.getSymName(), @@ -2252,7 +2253,7 @@ void EmbeddingPipeliningPass::runOnOperation() { orig_while_op.body_function().erase(); orig_while_op.erase(); - VLOG(3) << "EmbeddingPipeliningPass::runOnOperation done."; + LOG(INFO) << "EmbeddingPipeliningPass::runOnOperation done."; } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc index 7ed29a3ed58cc3..577b374a43847d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc @@ -413,8 +413,7 @@ LogicalResult FindForwardPassOps(OpBuilder& builder, if (is_non_variable && is_variable) { loop_body_func.emitOpError() << "resource input " << argument.getArgNumber() - << " is used both as a varible and not " - << " a variable"; + << " is used both as a varible and not a variable"; return LogicalResult::failure(); } if (is_variable && use_in_forward) @@ -772,7 +771,7 @@ LogicalResult ExtractOpsAsFunc( } void EmbeddingSequencingPass::runOnOperation() { - VLOG(3) << "EmbeddingSequencingPass::runOnOperation()"; + LOG(INFO) << "EmbeddingSequencingPass::runOnOperation()"; ModuleOp module = 
getOperation(); llvm::SetVector forward_pass_ops; @@ -803,6 +802,8 @@ void EmbeddingSequencingPass::runOnOperation() { // If there are no forward pass ops, there is no SC, so we end early. if (forward_pass_ops.empty()) { if (backward_pass_ops.empty()) { + LOG(INFO) << "No unprocessed embedding ops found - skipping embedding " + << "sequencing rewrite."; return; } else { (*backward_pass_ops.begin())->emitOpError() @@ -810,7 +811,7 @@ void EmbeddingSequencingPass::runOnOperation() { return signalPassFailure(); } } - VLOG(1) << "Embedding sequencing rewrite enabled."; + LOG(INFO) << "Embedding sequencing rewrite enabled."; // Ensure that all ops are in the same region, and have the same replication // info. @@ -860,18 +861,17 @@ void EmbeddingSequencingPass::runOnOperation() { TF::WhileOp while_op = nullptr; result = FindOwningWhileOp(loop_body_func, module, &while_op); if (failed(result)) { - VLOG(1) << "WhileOp not found: assuming external loop."; + LOG(INFO) << "WhileOp not found: assuming external loop."; } else { // Override the WhileOp parallel_iterations if requested by flag. 
int parallel_iterations_flag = tensorflow::GetBuildXlaOpsPassFlags() ->tf_xla_embedding_parallel_iterations; if (parallel_iterations_flag > 0) { - VLOG(1) << "Setting WhileOp parallel_iterations_flag to " - << parallel_iterations_flag; + LOG(INFO) << "Setting WhileOp parallel_iterations_flag to " + << parallel_iterations_flag; while_op.setParallelIterations(parallel_iterations_flag); } else { - VLOG(1) << "Using original WhileOp parallel_iterations = " - << while_op.getParallelIterations(); + LOG(INFO) << "Using original WhileOp parallel_iteration"; } } @@ -898,11 +898,11 @@ void EmbeddingSequencingPass::runOnOperation() { if (failed(result)) return signalPassFailure(); merged_set.insert(non_tpu_ops.begin(), non_tpu_ops.end()); - VLOG(2) << "Forwards pass " << forward_pass_ops.size() - << " ops, backwards pass " << backward_pass_ops.size() - << " ops, core " << core_tpu_ops.size() - << " ops. Total = " << merged_set.size() << " of " - << GetNumOps(loop_body_func) << ".\n"; + LOG(INFO) << "Forwards pass " << forward_pass_ops.size() + << " ops, backwards pass " << backward_pass_ops.size() + << " ops, core " << core_tpu_ops.size() + << " ops. 
Total = " << merged_set.size() << " of " + << GetNumOps(loop_body_func) << ".\n"; builder.setInsertionPointAfter(*non_tpu_ops.begin()); Operation* non_tpu_caller = nullptr; @@ -936,7 +936,7 @@ void EmbeddingSequencingPass::runOnOperation() { metadata_op->erase(); compilation_op->erase(); - VLOG(3) << "EmbeddingSequencingPass::runOnOperation done."; + LOG(INFO) << "EmbeddingSequencingPass::runOnOperation done."; } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td index c89c909375df67..169f5f206dabc5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td @@ -330,54 +330,3 @@ def XlaValidateInputsPass : Pass<"tf-xla-validate-inputs", "ModuleOp"> { let constructor = "TFDevice::CreateXlaValidateInputsPass()"; let dependentDialects = ["tf_device::TensorFlowDeviceDialect"]; } - -def XlaRewriteV2Pass : Pass<"tf-xla-rewrite-v2", "mlir::ModuleOp"> { - let summary = "Rewrites `tf_device.cluster_func op` into `_XlaCompile` and `_XlaRun` ops to make the attached function run on XLA."; - - let description = [{ - This pass rewrites `tf_device.cluster_func` op into - `tf._XlaCompile` op and `tf._XlaRun` op. This makes the attached - function execute with XLA. `tf.XlaCompile` requires resource-type arguments - come at the end, so this pass rewrites the called function if necessary. - This pass assumes there are no nested `tf_device.cluster`s so we don't end - up creating nested `tf._XlaCompile` and `tf._XlaRun` ops. 
- - For example, the `tf_device.cluster_func` operation in the following code - - ```mlir - func.func @convert_cluster_func_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf_device.cluster_func"(%arg0, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0", func = @func_with_resources} : (tensor, tensor) -> tensor - return %0 : tensor - } - - func.func @func_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { - return %arg1 : tensor - } - ``` - - will be replaced by a `tf._XlaCompile` and `tf._XlaRun` operation. - - ```mlir - func.func @convert_cluster_func_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { - %0:2 = "tf_device.launch"() ({ - %key, %compilation_successful = "tf._XlaCompile"(%arg1, %arg0) {function = @func_with_resources, must_compile = true, operand_segment_sizes = array} : (tensor, tensor) -> (tensor<3x!tf_type.string>, tensor) - tf_device.return %key, %compilation_successful : tensor<3x!tf_type.string>, tensor - }) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : () -> (tensor<3x!tf_type.string>, tensor) - %1 = "tf_device.launch"() ({ - %2 = "tf._XlaRun"(%arg1, %arg0, %0#0) : (tensor, tensor, tensor<3x!tf_type.string>) -> tensor - tf_device.return %2 : tensor - }) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : () -> tensor - return %1 : tensor - } - - func.func @func_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { - return %arg0 : tensor - } - ``` - Notice that the called function is rewritten, with the order of its parameters changed. 
- }]; - - let constructor = "TFDevice::CreateXlaRewriteV2Pass()"; - let dependentDialects = ["tf_device::TensorFlowDeviceDialect"]; -} - diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc deleted file mode 100644 index f8752e316233dd..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc +++ /dev/null @@ -1,397 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This transformation pass converts tf_device.cluster_func op into -// tf._XlaCompile and tf._XlaRun ops. 
- -#include -#include - -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/parallel_execute_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h" - -#define DEBUG_TYPE "tf-xla-rewrite-v2" - -namespace mlir { -namespace { - -#define GEN_PASS_DEF_XLAREWRITEV2PASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.h.inc" - -constexpr absl::string_view kDeviceAttr = "device"; - -struct XlaRewriteV2Pass : public impl::XlaRewriteV2PassBase { - void runOnOperation() override; -}; - -// Get the device from `tf_device.cluster_func` op -mlir::LogicalResult GetClusterFuncDevice(tf_device::ClusterFuncOp cluster_func, - std::string& compilation_device) { - auto device_attr = cluster_func->getAttrOfType(kDeviceAttr); - if (device_attr) { - compilation_device = device_attr.str(); - } else { - return cluster_func.emitOpError("No device assigned for cluster_func "); - } - return success(); -} - -// Rearrange the input order by putting resource args after non resource args -// Returns true when the inputs is in order, otherwise return false -bool RearrangeInputOrder(llvm::SmallVector inputs, - llvm::SmallVector& non_resource_args, - llvm::SmallVector& resource_args) { - bool has_resources = false; - bool in_order = true; - for (const Value& arg : inputs) { - if (!getElementTypeOrSelf(arg.getType()).template isa()) { - non_resource_args.push_back(arg); - if (has_resources) in_order = false; - } else { - resource_args.push_back(arg); - has_resources = true; - } - } - return in_order; -} - -// Move the resource args to the end of the function operand list. 
-void MoveResourceArgsToEnd(func::FuncOp callee) { - llvm::DenseMap mapping; - unsigned num_params = callee.getNumArguments(); - llvm::BitVector removed_params(num_params); - // Copy the resource-type parameters to the end. - for (unsigned i = 0; i < num_params; ++i) { - BlockArgument param = callee.getArgument(i); - if (getElementTypeOrSelf(param.getType()) - .template isa()) { - removed_params.set(i); - callee.getBody().addArgument(param.getType(), param.getLoc()); - param.replaceAllUsesWith(callee.getArguments().back()); - removed_params.push_back(false); - } - } - // Remove old resource-type parameters. - callee.getBody().front().eraseArguments(removed_params); - // Update function type. - callee.setFunctionType(FunctionType::get(callee.getContext(), - callee.getBody().getArgumentTypes(), - callee.getResultTypes())); -} - -mlir::LogicalResult GetOutputTypesForClusterFunc( - mlir::tf_device::ClusterFuncOp cluster_func, - llvm::SmallVectorImpl* output_types) { - output_types->reserve(cluster_func.getNumResults()); - for (const auto& result_and_index : - llvm::enumerate(cluster_func.getResults())) { - const auto cluster_func_output_type = - result_and_index.value().getType().cast(); - output_types->emplace_back(cluster_func_output_type); - } - return mlir::success(); -} - -mlir::LogicalResult ExtractInputsForLogicalDevices( - const int num_cores_per_replica, - mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder, - llvm::SmallVectorImpl>* input_list) { - // Initialize the input list for each logical devices. - input_list->reserve(num_cores_per_replica); - for (int i = 0; i < num_cores_per_replica; ++i) - input_list->emplace_back(llvm::SmallVector()); - - llvm::SmallVector cluster_func_inputs( - cluster_func.getOperands()); - - // If sharding attribute does not exist, then all inputs are placed on 0th - // logical core by default. 
- (*input_list)[0] = cluster_func_inputs; - return mlir::success(); -} - -// Creates a `tf._XlaRun` op that executes XLA program. -LogicalResult BuildExecuteOp(llvm::SmallVector input, - tf_device::ClusterFuncOp cluster_func, - Operation* compile_op, int core, - OpBuilder* builder, TF::_XlaRunOp* execute_op) { - llvm::SmallVector output_types; - llvm::SmallVector cluster_to_core_index; - auto result = GetOutputTypesForClusterFunc(cluster_func, &output_types); - if (failed(result)) return failure(); - - llvm::SmallVector non_resource_args, resource_args; - bool in_order = RearrangeInputOrder(input, non_resource_args, resource_args); - - llvm::SmallVector execute_inputs; - if (!in_order) { - for (auto non_resource_arg : non_resource_args) { - execute_inputs.emplace_back(non_resource_arg); - } - for (auto resource_arg : resource_args) { - execute_inputs.emplace_back(resource_arg); - } - } else { - execute_inputs = input; - } - execute_inputs.emplace_back(compile_op->getResult(core)); - - // _XlaRun op has same output types as cluster_func. - *execute_op = builder->create(cluster_func.getLoc(), - output_types, execute_inputs); - return success(); -} - -// parallel_execute op returns concatenated list of return values of all its -// regions. -mlir::LogicalResult GetConcatenatedOutputTypes( - const int num_cores_per_replica, tf_device::ClusterFuncOp cluster_func, - tf_device::ParallelExecuteOp old_parallel_execute, - const ValueTypeRange& cluster_result_types, - llvm::SmallVector& concatenated_output_types) { - // parallel_execute op returns concatenated list of return values of - // all its regions. 
- concatenated_output_types.reserve(cluster_result_types.size() * - num_cores_per_replica); - for (mlir::Region& region : old_parallel_execute.getRegions()) { - if (!isa(region.front().front())) { - for (Type t : region.front().front().getResultTypes()) - concatenated_output_types.emplace_back(t); - } - } - - for (int core = 0; core < num_cores_per_replica; ++core) { - llvm::SmallVector output_types; - auto result = GetOutputTypesForClusterFunc(cluster_func, &output_types); - if (failed(result)) return failure(); - for (Type t : output_types) { - concatenated_output_types.emplace_back(t); - } - } - return success(); -} - -// Given a `ParallelExecute`, replace it with a new `ParallelExecute`. The -// new `ParallelExecute` will replace the child that contains the -// `ClusterFunc` with `num_cores_per_replica` children. It keep other children -// the same. Return values from the child with the `ClusterFunc` will be -// duplicated `num_cores_per_replica` times. -LogicalResult AddToParallelExecuteOp( - llvm::SmallVectorImpl>* cluster_to_core_index, - Operation* compile_op, tf_device::ClusterFuncOp cluster_func, - OpBuilder* builder, tf_device::ParallelExecuteOp old_parallel_execute, - tf_device::ParallelExecuteOp* new_parallel_execute, int* cluster_idx) { - const int num_cores_per_replica = 1; - const auto cluster_result_types = cluster_func.getResultTypes(); - llvm::SmallVector concatenated_output_types; - - if (failed(GetConcatenatedOutputTypes( - num_cores_per_replica, cluster_func, old_parallel_execute, - cluster_result_types, concatenated_output_types))) - return failure(); - - *cluster_idx = tensorflow::MovePreservedParallelExecuteChildren( - num_cores_per_replica, concatenated_output_types, builder, cluster_func, - old_parallel_execute, new_parallel_execute); - - // Extract inputs for each block of the parallel_execute op. The i-th - // element in the list represents the input lists to XLA computation for - // i-th logical core. 
- llvm::SmallVector, 4> input_list; - builder->setInsertionPoint(*new_parallel_execute); - auto result = ExtractInputsForLogicalDevices( - num_cores_per_replica, cluster_func, builder, &input_list); - if (failed(result)) return failure(); - - // For each logical core, create a region with tf._XlaRun op. - for (int core = 0; core < num_cores_per_replica; ++core) { - auto& block = - new_parallel_execute->GetRegionBlockWithIndex((*cluster_idx) + core); - builder->setInsertionPointToEnd(&block); - - // Create Execute op _XlaRun. - TF::_XlaRunOp execute; - if (failed(BuildExecuteOp(input_list[core], cluster_func, compile_op, core, - builder, &execute))) - return failure(); - - std::string execute_device; - if (failed(GetClusterFuncDevice(cluster_func, execute_device))) - return failure(); - - auto block_launch_op = tensorflow::WrapOpInLaunch( - builder, block.getParent()->getLoc(), execute, execute_device); - - builder->create(block.getParent()->getLoc(), - block_launch_op.getResults()); - } - - return success(); -} - -// Replace the uses of old parallel execute outputs with new outputs -mlir::LogicalResult RemapOutputsFromLogicalDevices( - mlir::tf_device::ParallelExecuteOp old_parallel_execute, int cluster_idx, - mlir::tf_device::ParallelExecuteOp new_parallel_execute, - mlir::OpBuilder* builder) { - for (auto [output_index, old_parallel_execute_output] : - llvm::enumerate(old_parallel_execute.getResults())) { - const auto output_from_logical_device = - new_parallel_execute.GetRegionOutputs(cluster_idx)[output_index]; - old_parallel_execute_output.replaceAllUsesWith(output_from_logical_device); - } - return mlir::success(); -} - -// Create a `tf._XlaCompile` op -Operation* BuildCompileOp(tf_device::ClusterFuncOp cluster_func, - llvm::StringRef compilation_device, - SymbolTable& symtab, OpBuilder* builder) { - llvm::SmallVector non_resource_args, resource_args; - bool in_order = RearrangeInputOrder(cluster_func.getOperands(), - non_resource_args, resource_args); - if 
(!in_order) { - // Functions do not get reused in practice, so skip the check for if the - // callee has been updated. - StringAttr callee_sym = cluster_func.getFuncAttr().getAttr(); - MoveResourceArgsToEnd(symtab.lookup(callee_sym)); - } - - auto program_type = - RankedTensorType::get({3}, builder->getType()); - auto compilation_status_type = - RankedTensorType::get({}, builder->getType()); - auto compile_op = builder->create( - cluster_func.getLoc(), program_type, compilation_status_type, - /*constants=*/ValueRange({}), ValueRange(non_resource_args), - ValueRange(resource_args), builder->getBoolAttr(true), - cluster_func.getFuncAttr()); - return tensorflow::WrapOpInLaunch(builder, compile_op.getLoc(), compile_op, - compilation_device); -} - -mlir::LogicalResult GetCompilationDeviceFromParallelExecuteOp( - tf_device::ParallelExecuteOp& old_parallel_execute, - std::string& compilation_device) { - auto& first_block = old_parallel_execute.GetRegionBlockWithIndex(0); - if (isa(first_block.front())) { - auto device_attr = - first_block.front().getAttrOfType(kDeviceAttr); - if (device_attr) { - compilation_device = device_attr.str(); - } else { - return failure(); - } - } - return success(); -} - -mlir::LogicalResult Rewrite(tf_device::ClusterFuncOp cluster_func, - SymbolTable& symtab, OpBuilder& builder) { - // Fetch the ParallelExecute parent of `cluster_func`, or create it if - // it does not exist. 
- tf_device::ParallelExecuteOp old_parallel_execute = - cluster_func->getParentOfType(); - if (old_parallel_execute && - cluster_func->getParentOp() != old_parallel_execute) { - cluster_func->emitError() << "The ParallelExecute ancestor of a " - "ClusterFunc must be its direct parent."; - } - - // Fetch compilation device - std::string compilation_device; - if (failed(GetClusterFuncDevice(cluster_func, compilation_device))) - return failure(); - - if (!old_parallel_execute) { - old_parallel_execute = - mlir::TF::BuildParallelExecuteOp(cluster_func, &builder); - } - - // Build compile op _XlaCompile - builder.setInsertionPoint(old_parallel_execute); - Operation* compile_op = - BuildCompileOp(cluster_func, compilation_device, symtab, &builder); - if (!compile_op) { - return failure(); - } - - old_parallel_execute.walk( - [&](TF::_XlaCompileMlirPlaceholderProgramKeyOp key_op) { - key_op.replaceAllUsesWith(compile_op->getResult(0)); - key_op.erase(); - }); - - // Build new parallel execute op - tf_device::ParallelExecuteOp new_parallel_execute; - int num_cores_per_replica = 1; - int cluster_idx; - llvm::SmallVector, 4> cluster_to_core_index; - cluster_to_core_index.reserve(num_cores_per_replica); - - if (failed(AddToParallelExecuteOp( - &cluster_to_core_index, compile_op, cluster_func, &builder, - old_parallel_execute, &new_parallel_execute, &cluster_idx))) - return failure(); - - // As tf_device.parallel_execute wraps # logical cores number of tf._XlaRun - // ops, the number of return values of parallel_execute op may exceed that of - // cluster_func op. As such, each return value of parallel_execute op must - // be mapped with corresponding return value usages of cluster_func. 
- if (failed(RemapOutputsFromLogicalDevices(old_parallel_execute, cluster_idx, - new_parallel_execute, &builder))) - return failure(); - - if (failed(mlir::TF::RemoveSingletonParallelExecuteOp(new_parallel_execute, - &builder))) - return failure(); - - return success(); -} - -void XlaRewriteV2Pass::runOnOperation() { - ModuleOp module = getOperation(); - SymbolTable symtab(module); - OpBuilder builder(&getContext()); - llvm::SmallVector cluster_func_ops; - module.walk([&](tf_device::ClusterFuncOp cluster_func) { - cluster_func_ops.push_back(cluster_func); - }); - - for (tf_device::ClusterFuncOp cluster_func : cluster_func_ops) { - if (failed(Rewrite(cluster_func, symtab, builder))) - return signalPassFailure(); - } - - // Erase all the tf_device.cluster_func ops - if (failed(tensorflow::EraseClusterFuncs(cluster_func_ops))) { - return signalPassFailure(); - } -} - -} // namespace - -namespace TFDevice { -std::unique_ptr> CreateXlaRewriteV2Pass() { - return std::make_unique(); -} - -} // namespace TFDevice -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 95a637dbfbb3b3..42b059cbd0a527 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -4280,17 +4280,13 @@ SavedModelMLIRImportInput::~SavedModelMLIRImportInput() {} absl::StatusOr> ConvertGraphdefToMlir( const GraphDef& graphdef, const GraphDebugInfo& debug_info, - const GraphImportConfig& specs, mlir::MLIRContext* context, - bool add_default_attributes) { + const GraphImportConfig& specs, mlir::MLIRContext* context) { GraphConstructorOptions options; options.allow_internal_ops = true; - options.add_default_attributes = add_default_attributes; Graph graph(OpRegistry::Global()); - GraphDef preprocessed_graphdef(graphdef); - if (add_default_attributes) { - TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, 
&preprocessed_graphdef)); - } + TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef)); + if (specs.upgrade_legacy) { TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty( preprocessed_graphdef, graph.flib_def().default_registry())); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 182a53078ba215..1670fd11a1f819 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -41,8 +41,7 @@ inline constexpr absl::string_view kImportModelDefaultGraphFuncName = "main"; // tf_executor dialect. tsl::StatusOr> ConvertGraphdefToMlir( const GraphDef& graphdef, const GraphDebugInfo& debug_info, - const GraphImportConfig& specs, mlir::MLIRContext* context, - bool add_default_attributes = true); + const GraphImportConfig& specs, mlir::MLIRContext* context); // Given a Graph, returns a MLIR module containing the graph, expressed with // tf_executor dialect. 
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc index e674989d2174ba..2dda3809fc2b9b 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc @@ -60,16 +60,16 @@ class SessionClusterTensorflowDialectTest : public ::testing::Test { context_.loadAllAvailableDialects(); } - tsl::Status CreateMlirModule(std::string mlir_module_filename) { + absl::Status CreateMlirModule(std::string mlir_module_filename) { std::string mlir_module_path = TestDataPath() + mlir_module_filename; mlir_module_ = mlir::parseSourceFile(mlir_module_path, &context_); if (!mlir_module_) { - return tsl::Status( + return absl::Status( absl::StatusCode::kNotFound, absl::StrCat("Could not find MLIR module at ", mlir_module_path)); } - return tsl::OkStatus(); + return absl::OkStatus(); } DialectRegistry registry_; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc index 59fb22e87eab58..322862828e63b3 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc @@ -110,7 +110,7 @@ Status MaybeRewriteLayoutWithShardedShape( mlir::StringAttr sharding, const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, xla::Shape* shape) { - if (!sharding) return OkStatus(); + if (!sharding) return absl::OkStatus(); xla::OpSharding op_sharding; if (tensorflow::DecodeShardingAttribute(sharding, op_sharding).failed()) { @@ -121,7 +121,7 @@ Status MaybeRewriteLayoutWithShardedShape( TF_ASSIGN_OR_RETURN(hlo_sharding, xla::HloSharding::FromProto(op_sharding)); TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( hlo_sharding, /*use_fast_memory=*/false, shape_determination_fns, shape)); - return OkStatus(); + return absl::OkStatus(); } // Converts arg_shapes to xla::Shape's and store into 
xla_input_shapes. @@ -168,7 +168,7 @@ Status GetXlaInputShapes( } else { *xla_input_shapes = individual_arg_shapes; } - return OkStatus(); + return absl::OkStatus(); } // Returns a static ranked tensor type corresponding to the given static or @@ -307,7 +307,7 @@ Status GetOutputInfo( // XLA computation always uses Tuple shape. *xla_output_shape = xla::ShapeUtil::MakeTupleShape(shapes); - return OkStatus(); + return absl::OkStatus(); } // Creates a vector that maps from the parameters of the XLA computation to @@ -666,7 +666,7 @@ Status ConvertMLIRWithOptionalXlaComputation( module_op, &hlo_proto, use_tuple_args, return_tuple, options)); *xla_computation = xla::XlaComputation(hlo_proto.hlo_module()); } - return OkStatus(); + return absl::OkStatus(); } // Wraps the optional lowering version to keep the api the same for clients. @@ -692,7 +692,7 @@ Status CompileMlirSetup(mlir::ModuleOp module_op, if (VLOG_IS_ON(2)) tensorflow::DumpMlirOpToFile("compile_mlir_shape_refiner", module_op); - return OkStatus(); + return absl::OkStatus(); } Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder, @@ -715,7 +715,7 @@ Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder, if (VLOG_IS_ON(2)) tensorflow::DumpMlirOpToFile("build_hlo_tf_after", module_op); - return OkStatus(); + return absl::OkStatus(); } Status PopulateCollectiveInfo(mlir::ModuleOp module_op, @@ -729,7 +729,7 @@ Status PopulateCollectiveInfo(mlir::ModuleOp module_op, kGroupSizeAttrName.data(), kGroupSizeAttrName.size())); if (group_key_attr == nullptr && group_size_attr == nullptr) { // No CollectiveInfo is present. 
- return OkStatus(); + return absl::OkStatus(); } DCHECK(group_key_attr != nullptr) << "module attribute " << kGroupKeyAttrName @@ -742,7 +742,7 @@ Status PopulateCollectiveInfo(mlir::ModuleOp module_op, VLOG(2) << "Populating CollectiveInfo: group_key=" << group_key << " group_size=" << group_size; compilation_result->collective_info = {group_key, group_size, 0}; - return OkStatus(); + return absl::OkStatus(); } Status PopulateResultIOInfo( @@ -945,7 +945,7 @@ Status CompileGraphSetup( if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("compile_graph_setup_after", module_op); - return OkStatus(); + return absl::OkStatus(); } Status BuildHloFromModule(mlir::ModuleOp module_op, xla::XlaBuilder& builder, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h index aaccd39a3db398..6a7a4c42ee3d52 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h @@ -147,7 +147,7 @@ Status PopulateResultIOInfo( // If enable_op_fallback is set to false, graph is legalized only if the graph // analysis for the graph is successful. Otherwise, an error is returned. ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") -StatusOr CompileMlirToXlaHlo( +absl::StatusOr CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, bool use_return_tuple, bool use_resource_updates_for_aliases, @@ -163,7 +163,7 @@ StatusOr CompileMlirToXlaHlo( // If lower_to_xla_hlo is true then compiles down into XLA HLO, generates all // accompanying metadata and stores them in CompilationResult. 
ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") -StatusOr CompileSerializedMlirToXlaHlo( +absl::StatusOr CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index 0355204506068c..a289e3a6d84148 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -130,11 +130,11 @@ Status PopulateInputOutputAliasing( output_to_input_alias[aliasing_output.getInt()] = arg_index; } - if (output_to_input_alias.empty()) return OkStatus(); + if (output_to_input_alias.empty()) return absl::OkStatus(); xla::HloModuleProto* module_proto = compilation_result->computation->mutable_proto(); - StatusOr program_shape_or_status = + absl::StatusOr program_shape_or_status = compilation_result->computation->GetProgramShape(); TF_RET_CHECK(program_shape_or_status.ok()); @@ -155,10 +155,10 @@ Status PopulateInputOutputAliasing( } } *module_proto->mutable_input_output_alias() = config.ToProto(); - return OkStatus(); + return absl::OkStatus(); } -bool failed(const tsl::Status& status) { return !status.ok(); } +bool failed(const absl::Status& status) { return !status.ok(); } // Transforms the given module to be suitable for export to TensorFlow GraphDef // and then exports all functions to the given library. 
@@ -203,7 +203,7 @@ Status PrepareAndExportToLibrary(mlir::ModuleOp module, flib_def); } -tsl::Status CompileTFFunctionWithoutMlir( +absl::Status CompileTFFunctionWithoutMlir( FunctionToHloArgs function_computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, const XlaShapeLayoutHelpers::ShapeDeterminationFns @@ -230,7 +230,7 @@ tsl::Status CompileTFFunctionWithoutMlir( return comp_status; } -tsl::Status CompileMLIRTFFunction( +absl::Status CompileMLIRTFFunction( tpu::MlirToHloArgs mlir_computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, const XlaShapeLayoutHelpers::ShapeDeterminationFns @@ -293,7 +293,7 @@ tsl::Status CompileMLIRTFFunction( } // namespace -tsl::Status CompileTensorflowGraphToHlo( +absl::Status CompileTensorflowGraphToHlo( const std::variant& computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, const XlaShapeLayoutHelpers::ShapeDeterminationFns @@ -331,7 +331,7 @@ tsl::Status CompileTensorflowGraphToHlo( phase2_bridge_compilation_time->GetCell(kBridgePhase2Config) ->Add(timer.ElapsedCyclesInMilliseconds()); - return tsl::OkStatus(); + return absl::OkStatus(); } }; // namespace v1 diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h index 0a4d8709393ef9..c3f2a6d2d0d868 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h @@ -34,7 +34,7 @@ namespace v1 { // Compiles the given Tensorflow graph into xla::HLO. The result is in // compilation_result. If the input computation is in MLIR, it will be // converted to a Tensorflow graph. Otherwise, the graph compiler will be run. 
-tsl::Status CompileTensorflowGraphToHlo( +absl::Status CompileTensorflowGraphToHlo( const std::variant& computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_funcs, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc index fdff5122c3516e..06208be8fc5893 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc @@ -76,7 +76,7 @@ MlirToHloArgs CreateTestMlirToHloArgs(const char* module_str = kMlirModuleStr) { class CompileTFGraphTest : public ::testing::Test { public: - tsl::StatusOr CompileWithComputation( + absl::StatusOr CompileWithComputation( const std::variant computation) { XlaCompilationResult compilation_result; @@ -99,7 +99,7 @@ class CompileTFGraphTest : public ::testing::Test { XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns; - tsl::Status compilation_status = + absl::Status compilation_status = tensorflow::tf2xla::v1::CompileTensorflowGraphToHlo( computation, metadata_proto, use_tuple_args, shape_determination_fns, arg_shapes, &arg_core_mapping, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc index 38393d3753146e..cad1edf2b89018 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc @@ -61,16 +61,16 @@ class TensorflowDialectToExecutorTest : public ::testing::Test { context_.loadAllAvailableDialects(); } - tsl::Status CreateMlirModule(std::string mlir_module_filename) { + absl::Status CreateMlirModule(std::string mlir_module_filename) { std::string mlir_module_path = TestDataPath() + mlir_module_filename; mlir_module_ = mlir::parseSourceFile(mlir_module_path, 
&context_); if (!mlir_module_) { - return tsl::Status( + return absl::Status( absl::StatusCode::kNotFound, absl::StrCat("Could not find MLIR module at ", mlir_module_path)); } - return tsl::OkStatus(); + return absl::OkStatus(); } DialectRegistry registry_; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc index a5f64a91cd8cb4..14a9c1b1a99bff 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc @@ -66,16 +66,16 @@ class FunctionClusterTensorflowDialectTest : public ::testing::Test { context_.loadAllAvailableDialects(); } - tsl::Status CreateMlirModule(std::string mlir_module_filename) { + absl::Status CreateMlirModule(std::string mlir_module_filename) { std::string mlir_module_path = TestDataPath() + mlir_module_filename; mlir_module_ = mlir::parseSourceFile(mlir_module_path, &context_); if (!mlir_module_) { - return tsl::Status( + return absl::Status( absl::StatusCode::kNotFound, absl::StrCat("Could not find MLIR module at ", mlir_module_path)); } - return tsl::OkStatus(); + return absl::OkStatus(); } DialectRegistry registry_; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc index d297e45b70e0bb..d84e4d8692a19d 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc @@ -51,8 +51,6 @@ namespace tensorflow { namespace tf2xla { namespace v2 { -using metrics::IncrementTfMlirBridgeSecondPhaseCounter; -using metrics::MlirBridgeSecondPhaseMetric; using tpu::FunctionToHloArgs; using tpu::MlirToHloArgs; using tpu::ShardingAndIndex; @@ -110,7 +108,7 @@ Status DumpHloCompilationResult(std::string_view name, XlaCompilationResult* compilation_result) { if (!VLOG_IS_ON(2) && !DEBUG_DATA_DUMPER()->ShouldDump(std::string(name), kDebugGroupMain)) { - return OkStatus(); + return 
absl::OkStatus(); } TF_ASSIGN_OR_RETURN( @@ -130,12 +128,12 @@ Status DumpHloCompilationResult(std::string_view name, tensorflow::DumpRawStringToFile(name, all_computations); - return OkStatus(); + return absl::OkStatus(); } } // namespace -tsl::StatusOr LegalizeMlirToHlo( +absl::StatusOr LegalizeMlirToHlo( const std::variant& computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, llvm::StringRef device_type, @@ -185,7 +183,7 @@ tsl::StatusOr LegalizeMlirToHlo( } VLOG(1) << "Failed to compile MLIR computation to XLA HLO using Combined " - "MLIR and XlaBuilder Bridge. Falling back to MLIR tf2xla Bridge. " + "MLIR and XlaBuilder Bridge. Failed to lower to hlo." << combined_bridge_status.status(); tsl::error_logging::Log(kBridgeComponent, "TFXLA_API_V2_COMBINED_BRIDGE", combined_bridge_status.status().ToString()) @@ -201,7 +199,7 @@ tsl::StatusOr LegalizeMlirToHlo( VLOG(1) << "Successfully compiled MLIR computation to XLA HLO using MLIR " "tf2xla Bridge"; IncrementTfMlirBridgeSecondPhaseCounter( - MlirBridgeSecondPhaseMetric::kMlirWithFallbackModeSuccess); + metrics::MlirBridgeSecondPhaseMetric::kMlirWithFallbackModeSuccess); DumpHloCompilationResult("legalize_tf_mlir_bridge.hlo", compilation_result.get()) @@ -219,7 +217,7 @@ tsl::StatusOr LegalizeMlirToHlo( mlir_bridge_status.status().ToString()) .IgnoreError(); IncrementTfMlirBridgeSecondPhaseCounter( - MlirBridgeSecondPhaseMetric::kMlirWithFallbackModeFailure); + metrics::MlirBridgeSecondPhaseMetric::kMlirWithFallbackModeFailure); } return mlir_bridge_status.status(); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h index c3dc6e18b92a1e..14a8271de171d1 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h @@ -50,7 +50,7 @@ namespace v2 { // arg_core_mapping - Which args go on which cores. 
// per_core_arg_shapes - For each core, the shapes for each argument. // client - The Xla Compilation client. -tsl::StatusOr LegalizeMlirToHlo( +absl::StatusOr LegalizeMlirToHlo( const std::variant& computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, llvm::StringRef device_type, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc index 81b3b5a180eb93..0e7e61999d8f2b 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc @@ -101,7 +101,7 @@ static constexpr char kUnsupportedMlirBridgeModuleStr[] = R"( } })"; -tsl::StatusOr CompileMlirModule( +absl::StatusOr CompileMlirModule( const char* mlir_module_str, ConfigProto::Experimental::MlirBridgeRollout rollout_state) { MlirToHloArgs mlir_to_hlo_args; @@ -291,7 +291,7 @@ TEST(LegalizeTFTest, RecordsStreamzForNoMlirFallback) { std::vector> custom_legalization_passes; // This doesn't actually compile correctly. 
- tsl::StatusOr compile_result = + absl::StatusOr compile_result = LegalizeMlirToHlo(function_to_hlo_args, metadata_proto, use_tuple_args, /*device_type=*/"XLA_CPU_JIT", custom_legalization_passes, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc index 9940a8d52c18e8..0c64dd3dcbe1a3 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc @@ -61,16 +61,16 @@ class TensorflowDialectToExecutorTest : public ::testing::Test { context_.loadAllAvailableDialects(); } - tsl::Status CreateMlirModule(std::string mlir_module_filename) { + absl::Status CreateMlirModule(std::string mlir_module_filename) { std::string mlir_module_path = TestDataPath() + mlir_module_filename; mlir_module_ = mlir::parseSourceFile(mlir_module_path, &context_); if (!mlir_module_) { - return tsl::Status( + return absl::Status( absl::StatusCode::kNotFound, absl::StrCat("Could not find MLIR module at ", mlir_module_path)); } - return tsl::OkStatus(); + return absl::OkStatus(); } DialectRegistry registry_; diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index e289934b69fbe0..d0909c452a0325 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -189,8 +189,6 @@ void AddNonReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm) { // inference. pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); - pm.addPass( - tensorflow::tf2xla::internal::CreateXlaOutlineEntryFunctionsPass()); // Encapsulate PartitionedCall ops within a cluster so that the composite // resource ops can be decomposed. 
pm.addPass(tensorflow::tf2xla::internal::CreateXlaClusterFormationPass()); @@ -200,12 +198,6 @@ void AddNonReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm) { pm.addNestedPass(mlir::createCanonicalizerPass()); // Decompose resource ops. pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsInClusterPass()); - // TODO(b/267193636): Remove this flag when outside compilation - // for generic pipeline is landed. - if (tensorflow::GetMlirCommonFlags() - ->tf_mlir_enable_generic_outside_compilation) { - pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); - } // Run another shape inference pass because resource decomposition might have // created new partial types. Also, after dropping `shape_invariant` attribute // from While/WhileRegion ops within cluster would lead to more precise @@ -220,17 +212,6 @@ void AddNonReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm) { // Lift resource operations out of device computation. This step needs to be // done after inlining. pm.addPass(mlir::TFDevice::CreateResourceOpLiftingPass()); - // TODO(b/267193636): Remove this flag when outside compilation - // for generic pipeline is landed. - if (tensorflow::GetMlirCommonFlags() - ->tf_mlir_enable_generic_outside_compilation) { - pm.addPass( - tensorflow::tf2xla::internal::CreateMarkOpsForOutsideCompilationPass()); - pm.addPass(tensorflow::tf2xla::internal:: - CreateExtractHeadTailOutsideCompilationPass()); - pm.addPass( - tensorflow::tf2xla::internal::CreateExtractOutsideCompilationPass()); - } // Outline clusters into cluster functions. 
pm.addPass(mlir::TFDevice::CreateClusterOutliningPass()); // Verifies clustering has conformed with the expected invariants diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc index d7d26cc1f7de50..9507ab371c1bcf 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc @@ -35,7 +35,7 @@ TEST(ClusteringBridgePassesTest, AddsNonTPUBridgePasses) { OpPassManager pass_manager; AddNonReplicatedBridgeClusteringPipelinePasses(pass_manager); - EXPECT_EQ(pass_manager.size(), 16); + EXPECT_EQ(pass_manager.size(), 15); } }; // namespace internal diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc index 4fd0c21d68331f..581f28fb4c9557 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc @@ -59,7 +59,7 @@ constexpr char kBridgeComponent[] = "TFXLABridge"; using tpu::MlirToHloArgs; using tpu::ShardingAndIndex; -tsl::StatusOr CompileFromMlirToXlaHlo( +absl::StatusOr CompileFromMlirToXlaHlo( bool lower_to_xla_hlo, const MlirToHloArgs& computation, const tpu::TPUCompileMetadataProto& metadata, llvm::StringRef device_type, const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns, @@ -103,7 +103,7 @@ tsl::StatusOr CompileFromMlirToXlaHlo( return compiled_mlir; } -tsl::StatusOr LegalizeWithMlirBridge( +absl::StatusOr LegalizeWithMlirBridge( const tpu::MlirToHloArgs& computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, llvm::StringRef device_type, @@ -119,7 +119,7 @@ tsl::StatusOr LegalizeWithMlirBridge( // Enabling op fallback also enables whole graph fallback if op by op // fallback failed. 
- tsl::StatusOr mlir_bridge_status = CompileFromMlirToXlaHlo( + absl::StatusOr mlir_bridge_status = CompileFromMlirToXlaHlo( /*lower_to_xla_hlo=*/true, computation, metadata, device_type, shape_determination_fns, use_tuple_args, compilation_result, custom_legalization_passes, arg_shapes, arg_core_mapping, diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h index 91a613f8c6f848..014d9a4f35d31c 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h @@ -33,7 +33,7 @@ namespace internal { // result of running all the MLIR Bridge passes. If compile_to_xla_hlo is true // then those passes include all the Legalization to XLA HLO which is returned // in the compilation_result. -tsl::StatusOr CompileFromMlirToXlaHlo( +absl::StatusOr CompileFromMlirToXlaHlo( bool lower_to_xla_hlo, const tpu::MlirToHloArgs& computation, const tpu::TPUCompileMetadataProto& metadata, llvm::StringRef device_type, const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns, @@ -45,7 +45,7 @@ tsl::StatusOr CompileFromMlirToXlaHlo( // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. 
-tsl::StatusOr LegalizeWithMlirBridge( +absl::StatusOr LegalizeWithMlirBridge( const tpu::MlirToHloArgs& computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, llvm::StringRef device_type, diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir_test.cc index 904bfa85a3cc98..c8b9577e0daa6a 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir_test.cc @@ -49,8 +49,8 @@ static constexpr char kMlirModuleStr[] = R"( } })"; -tsl::StatusOr CompileMlirModule(bool compile_to_xla_hlo, - const char* module_str) { +absl::StatusOr CompileMlirModule(bool compile_to_xla_hlo, + const char* module_str) { MlirToHloArgs mlir_to_hlo_args; mlir_to_hlo_args.mlir_module = module_str; @@ -71,7 +71,7 @@ tsl::StatusOr CompileMlirModule(bool compile_to_xla_hlo, &per_core_arg_shapes); } -tsl::StatusOr LegalizeMlirModule( +absl::StatusOr LegalizeMlirModule( const char* module_str) { MlirToHloArgs mlir_to_hlo_args; mlir_to_hlo_args.mlir_module = module_str; diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc index ba1a20a27ef751..e26741e0877d7b 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc @@ -44,7 +44,7 @@ using metrics::IncrementTfMlirBridgeSecondPhaseCounter; using metrics::MlirBridgeSecondPhaseMetric; using tpu::MlirToHloArgs; -tsl::StatusOr LegalizeTfToHlo( +absl::StatusOr LegalizeTfToHlo( const tpu::MlirToHloArgs& computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, llvm::StringRef device_type, @@ -57,8 +57,8 @@ tsl::StatusOr LegalizeTfToHlo( LOG_FIRST_N(INFO, 1) << "Compiling MLIR computation to XLA HLO using the " "Combined MLIR Tf2Xla Bridge."; - tsl::StatusOr mlir_compilation - = 
internal::CompileFromMlirToXlaHlo( + absl::StatusOr mlir_compilation = + internal::CompileFromMlirToXlaHlo( /*lower_to_xla_hlo=*/false, computation, metadata, device_type, shape_determination_fns, use_tuple_args, compilation_result, custom_legalization_passes, arg_shapes, arg_core_mapping, diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h index c0a8283ed30605..664bd549ed360d 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h @@ -29,7 +29,7 @@ namespace internal { // Legalize the given MLIR module to XLA HLO using a combination of the MLIR // Bridge and XlaBuilder -tsl::StatusOr LegalizeTfToHlo( +absl::StatusOr LegalizeTfToHlo( const tpu::MlirToHloArgs& computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, llvm::StringRef device_type, diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo_test.cc index ef5a82ed844728..686081c049e1b9 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo_test.cc @@ -72,7 +72,7 @@ static constexpr char kBadMlirModuleStr[] = R"( } })"; -tsl::StatusOr CompileMlirModule( +absl::StatusOr CompileMlirModule( const char* module_str) { MlirToHloArgs mlir_to_hlo_args; mlir_to_hlo_args.rollout_state = diff --git a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc index a52e316d0bc334..840d4c971e7bb5 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc @@ -67,16 +67,16 @@ class LoggingHooksTest : public ::testing::Test { setenv("TF_DUMP_GRAPH_PREFIX", test_dir_.c_str(), 1); } - tsl::Status 
CreateMlirModule(std::string mlir_module_filename) { + absl::Status CreateMlirModule(std::string mlir_module_filename) { std::string mlir_module_path = TestDataPath() + mlir_module_filename; mlir_module_ = mlir::parseSourceFile(mlir_module_path, &context_); if (!mlir_module_) { - return tsl::Status( + return absl::Status( absl::StatusCode::kNotFound, absl::StrCat("Could not find MLIR module at ", mlir_module_path)); } - return tsl::OkStatus(); + return absl::OkStatus(); } DialectRegistry registry_; diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc index 2f8d85d7325298..7bf4c74e094af5 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc @@ -76,6 +76,36 @@ LogicalResult HasAttr( return failure(); } +// Check if the `graph` has parameter server jobs and resource variable +// arguments that are on parameter servers +bool HasPsWithResourceVariable(const Graph& graph) { + // Check parameter serverjobs and resource variable arguments that are + // on parameter servers. 
+ const std::string jobType = "ps"; + const std::string nodeType = "_Arg"; + const std::string attrKey = "T"; + for (const Node* node : graph.nodes()) { + if (node->type_string() == nodeType) { + auto device_name = node->assigned_device_name(); + DeviceNameUtils::ParsedName device; + if (DeviceNameUtils::ParseFullName(device_name, &device) && + device.has_job && device.job == jobType) { + for (const auto& attr : node->attrs()) { + auto attr_key = attr.first; + auto attr_value = attr.second; + if (attr_key == attrKey && + attr_value.value_case() == AttrValue::kType && + attr_value.type() == DT_RESOURCE) { + return true; + break; + } + } + } + } + } + return false; +} + bool IsNonReplicatedGraph(const Graph& graph, const FunctionLibraryDefinition* function_library) { auto predicate = [](const Graph& graph) { @@ -111,22 +141,6 @@ bool IsReplicatedGraph(const Graph& graph, return HasAttr(graph, function_library, predicate).succeeded(); } -bool IsSingleCoreTpuGraph(const Graph& graph, - const FunctionLibraryDefinition* function_library) { - auto predicate = [](const Graph& graph) { - for (const Node* node : graph.nodes()) { - // _xla_compile_device_type=TPU is found in single-core TPU graphs. - auto attr = - node->attrs().FindByString(std::string(kCompileDeviceTypeAttr)); - if (attr && attr->s() == kTpuDevice) { - return true; - } - } - return false; - }; - return HasAttr(graph, function_library, predicate).succeeded(); -} - bool IsReplicatedGraph(mlir::ModuleOp module) { auto walk_result = module.walk([&](mlir::Operation* op) { // TODO(b/223677572): Once the scope for new compilation and replication @@ -144,25 +158,6 @@ bool IsReplicatedGraph(mlir::ModuleOp module) { return walk_result.wasInterrupted(); } -bool IsSingleCoreTPUGraph(mlir::ModuleOp module) { - auto walk_result = module.walk([&](mlir::Operation* op) { - // Check for ops with compile device type "TPU". This allows us to support - // TPU compilation without replication. 
Note that currently the compile - // device type is not set by default before bridge, only if eager context - // attribute `jit_compile_rewrite` is true. - // TODO(b/229028654): Remove string conversion once we have C++17. - const llvm::StringRef compile_device_type_attr_name( - kCompileDeviceTypeAttr.data(), kCompileDeviceTypeAttr.size()); - auto compilation_attr = - op->getAttrOfType(compile_device_type_attr_name); - if (compilation_attr && compilation_attr.getValue().str() == kTpuDevice) { - return mlir::WalkResult::interrupt(); - } - return mlir::WalkResult::advance(); - }); - return walk_result.wasInterrupted(); -} - // Traverses each node in the graph and check if any of them is // TPUPartitionedCall. If so, return true. Otherwise, return false. bool DoesGraphContainTPUPartitionedCall(const Graph& graph) { @@ -206,17 +201,17 @@ bool AreFunctionsFromFlibDefInference( bool IsSupportedByNonReplicatedBridge( const Graph& graph, const FunctionLibraryDefinition* function_library) { - return IsNonReplicatedGraph(graph, function_library); + return IsNonReplicatedGraph(graph, function_library) && + HasPsWithResourceVariable(graph); } bool IsSupportedByReplicatedBridge( const Graph& graph, const FunctionLibraryDefinition* function_library) { - return IsReplicatedGraph(graph, function_library) || - IsSingleCoreTpuGraph(graph, function_library); + return IsReplicatedGraph(graph, function_library); } bool IsSupportedByReplicatedBridge(mlir::ModuleOp module) { - return IsReplicatedGraph(module) || IsSingleCoreTPUGraph(module); + return IsReplicatedGraph(module); } bool HasTPUPartitionedCallOpInModule(mlir::ModuleOp module) { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc index 699964e989f1e9..6cbc67d4ec395c 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc 
@@ -44,28 +44,19 @@ namespace tensorflow { namespace { -FunctionDef OuterXTimesTwo() { +// Produce a valid graph with a resource-type input. +FunctionDef PassThroughResource() { return FunctionDefHelper::Define( - // Name - "OuterXTimesTwo", - // Args - {"x: float"}, - // Return values - {"y: float"}, - // Attr def - {}, - {{{"y"}, - "StatefulPartitionedCall", - {"x"}, - {{"Tin", DataTypeSlice{DT_FLOAT}}, - {"Tout", DataTypeSlice{DT_FLOAT}}, - {"f", - FunctionDefHelper::FunctionRef("XTimesTwoFloat", {{"T", DT_FLOAT}})}, - {std::string(kMustCompileAttr), true}}}}); + /*function_name=*/"PassThroughResource", + /*arg_def=*/{"in: resource"}, + /*ret_def=*/{"out: resource"}, + /*attr_def=*/{}, + /*node_def=*/ + {{{"out"}, "Identity", {"in"}, {{"T", DataType::DT_RESOURCE}}}}); } TEST(IsSupportedByNonReplicatedBridge, NonReplicatedGraph) { - const FunctionDef& fd = test::function::XTimesTwo(); + const FunctionDef& fd = PassThroughResource(); FunctionDefLibrary flib; *flib.add_function() = fd; FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); @@ -76,7 +67,7 @@ TEST(IsSupportedByNonReplicatedBridge, NonReplicatedGraph) { ConfigProto config = ConfigProto(); Scope root = Scope::NewRootScope().ExitOnError(); - Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); + Output a = ops::_Arg(root.WithOpName("A"), DT_RESOURCE, 0); std::vector inputs({NodeBuilder::NodeOut(a.node())}); Node* call; @@ -85,50 +76,21 @@ TEST(IsSupportedByNonReplicatedBridge, NonReplicatedGraph) { TF_ASSERT_OK( NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def()) .Input(inputs) - .Attr("Tin", {DT_FLOAT}) - .Attr("Tout", {DT_FLOAT}) + .Attr("Tin", {DT_RESOURCE}) + .Attr("Tout", {DT_RESOURCE}) .Attr("f", f_name_attr) .Finalize(root.graph(), &call)); call->AddAttr(std::string(kMustCompileAttr), true); TF_ASSERT_OK(root.ToGraph(&graph)); - EXPECT_TRUE( - IsSupportedByNonReplicatedBridge(graph, /*function_library=*/nullptr)); -} - -// Checks that HasAttr actually goes through 
function library. -TEST(IsSupportedByNonReplicatedBridge, NonReplicatedFunctionLibrary) { - const FunctionDef& fd = OuterXTimesTwo(); - FunctionDefLibrary flib; - *flib.add_function() = fd; - FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); - Graph graph(OpRegistry::Global()); - graph.SetConstructionContext(ConstructionContext::kEagerRuntime); - tensorflow::set_tf2_execution(true); - - ConfigProto config = ConfigProto(); - Scope root = Scope::NewRootScope().ExitOnError(); - - Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); - std::vector inputs({NodeBuilder::NodeOut(a.node())}); - - // Builds a call without compilation markers that calls a function with Xla - // clusters. - Node* call; - NameAttrList f_name_attr; - f_name_attr.set_name(fd.signature().name()); - TF_ASSERT_OK( - NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def()) - .Input(inputs) - .Attr("Tin", {DT_FLOAT}) - .Attr("Tout", {DT_FLOAT}) - .Attr("f", f_name_attr) - .Finalize(root.graph(), &call)); + // Required for passing the PS server parameter check. 
+ for (Node* node : graph.nodes()) { + node->set_assigned_device_name("/job:ps/replica:0/task:0/device:GPU:0"); + } - TF_ASSERT_OK(root.ToGraph(&graph)); EXPECT_TRUE( - IsSupportedByNonReplicatedBridge(graph, /*function_library=*/&flib_def)); + IsSupportedByNonReplicatedBridge(graph, /*function_library=*/nullptr)); } TEST(IsSupportedByReplicatedBridge, ReplicatedGraph) { @@ -164,39 +126,6 @@ TEST(IsSupportedByReplicatedBridge, ReplicatedGraph) { IsSupportedByReplicatedBridge(graph, /*function_library=*/nullptr)); } -TEST(IsSupportedByReplicatedBridge, SingleCoreTpuGraph) { - const FunctionDef& fd = test::function::XTimesTwo(); - FunctionDefLibrary flib; - *flib.add_function() = fd; - FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); - Graph graph(flib_def); - graph.SetConstructionContext(ConstructionContext::kEagerRuntime); - tensorflow::set_tf2_execution(true); - - ConfigProto config = ConfigProto(); - Scope root = Scope::NewRootScope().ExitOnError(); - - Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); - std::vector inputs({NodeBuilder::NodeOut(a.node())}); - - Node* call; - NameAttrList f_name_attr; - f_name_attr.set_name(fd.signature().name()); - TF_ASSERT_OK( - NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def()) - .Input(inputs) - .Attr("Tin", {DT_FLOAT}) - .Attr("Tout", {DT_FLOAT}) - .Attr("f", f_name_attr) - .Finalize(root.graph(), &call)); - call->AddAttr(std::string(kCompileDeviceTypeAttr), kTpuDevice); - - TF_ASSERT_OK(root.ToGraph(&graph)); - - EXPECT_TRUE( - IsSupportedByReplicatedBridge(graph, /*function_library=*/nullptr)); -} - TEST(IsSupportedByReplicatedBridge, ReplicatedModule) { const char* const code = R"mlir( func.func @entry_func_1(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { @@ -212,21 +141,6 @@ func.func @entry_func_1(%arg0: tensor) -> tensor attributes {tf.entry_ EXPECT_TRUE(IsSupportedByReplicatedBridge(*module)); } -TEST(IsSupportedByReplicatedBridge, SingleCoreTpuModule) { - 
const char* const code = R"mlir( -func.func @entry_func_1(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { - %0 = "tf.Identity"(%arg0) {_xla_compile_device_type = "TPU"} : (tensor) -> (tensor) - func.return %0 : tensor -} -)mlir"; - mlir::MLIRContext context; - context.loadDialect(); - mlir::OwningOpRef module = - mlir::parseSourceString(code, &context); - ASSERT_TRUE(module); - EXPECT_TRUE(IsSupportedByReplicatedBridge(*module)); -} - TEST(HasTPUPartitionedCallOpInModule, HasTPUPartitionedCallModule) { const char* const code = R"mlir( module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index 4c6f68a3419656..9641e092815b58 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -33,7 +33,6 @@ cc_library( ":verify_clustering_pass", ":xla_broadcast", ":xla_cluster_formation", - ":xla_outline_entry_functions", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -135,7 +134,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", "//tensorflow/core:framework", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", @@ -289,36 +287,6 @@ cc_library( ], ) -cc_library( - name = "xla_outline_entry_functions", - srcs = ["xla_outline_entry_functions.cc"], - textual_hdrs = [ - "clustering_passes.h.inc", - ], - deps = [ - ":clustering_passes_inc_gen", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:attribute_utils", - "//tensorflow/compiler/mlir/tensorflow:call_graph_util", - "//tensorflow/compiler/mlir/tensorflow:cluster_util", - 
"//tensorflow/compiler/mlir/tensorflow:string_util", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:portable_gif_internal", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - ], -) - cc_library( name = "mark_ops_for_outside_compilation", srcs = ["mark_ops_for_outside_compilation.cc"], @@ -425,22 +393,6 @@ cc_library( ], ) -tf_cc_test( - name = "tpu_cluster_formation_test", - srcs = ["tpu_cluster_formation_test.cc"], - deps = [ - ":clustering_passes", - "//tensorflow/compiler/mlir/tf2xla/transforms:test_utils", - "//tensorflow/core/lib/monitoring:cell_reader", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "lowering_passes", hdrs = [ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h index 5eba662c4bae60..fb6e32ac377b79 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h @@ -46,11 +46,6 @@ CreateExtractOutsideCompilationPass(); std::unique_ptr> CreateXlaClusterFormationPass(); -// Create a pass that rewrites entry functions with `_xla_compile_device` into a -// `tf.StatefulPartitionedCall` to the original function. 
-std::unique_ptr> -CreateXlaOutlineEntryFunctionsPass(); - // Creates a pass that marks unsupported ops in device cluster for outside // compilation. std::unique_ptr> diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td index 5b77ddd5afe991..2f617f7c154935 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td @@ -253,56 +253,6 @@ def XlaClusterFormationPass : Pass<"tf-xla-cluster-formation", "ModuleOp"> { let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"]; } -def XlaOutlineEntryFunctionsPass : Pass<"tf-xla-outline-entry-functions", "ModuleOp"> { - let summary = "Outline the body of an entry function into a call to the " - "original function body"; - let description = [{ - This pass adds support for top-level function with - `_xla_compile_device_type` attribute in MLIR generic pipeline. - It renames such a function, and creates a new function taking the original - name with a `tf.StatefulPartitionedCall` to the original function. It - allows the MLIR generic pipeline to handle such functions the same way it - handles other partitioned calls with the attribute. 
- - For example, the following code - - ```mlir - func.func @main(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true, tf.entry_function = {}} { - %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", _xla_compile_device_type = "CPU", f = @stateful_pcall_func} : (tensor) -> (tensor) - %1 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor - %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> (tensor) - func.return %2 : tensor - } - - func.func @stateful_pcall_func(%arg0: tensor) -> tensor { - func.return %arg0 : tensor - } - ``` - - will be replaced as - - ```mlir - func.func private @main_outlined(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true} { - %0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor) -> tensor - %cst = "tf.Const"() {value = dense<5> : tensor} : () -> tensor - %1 = "tf.Add"(%0, %cst) : (tensor, tensor) -> tensor - return %1 : tensor - } - - func.func @main(%arg0: tensor) -> tensor attributes {_xla_compile_device_type = "CPU", tf.entry_function = {}} { - %0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", allow_soft_placement = true, config = "", config_proto = "", executor_type = "", f = @main_outlined} : (tensor) -> tensor - return %0 : tensor - } - - func.func @stateful_pcall_func(%arg0: tensor) -> tensor { - func.return %arg0 : tensor - } - ``` - }]; - let constructor = "tensorflow::tf2xla::internal::CreateXlaOutlineEntryFunctionsPass()"; - let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"]; -} - def MarkOpsForOutsideCompilationPass : Pass<"tf-mark-ops-for-outside-compilation", "ModuleOp"> { let summary = "Marks ops in device cluster for outside compilation if they are unsupported on device."; diff --git 
a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc index 637369ed4fb6fc..b600c865661d58 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc @@ -59,7 +59,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/string_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" -#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -96,8 +95,6 @@ constexpr llvm::StringRef kNoReplicationCluster = "__no_replication_cluster"; constexpr llvm::StringRef kBadReplicateInfoAttrMsg = "requires '_replication_info' string attribute"; -constexpr char kUseMlirBridge[] = "kUseMlirBridge"; - // Mapping for `_replication_info` attribute to TPUReplicateMetadata attributes. using MetadataMap = llvm::SmallDenseMap; @@ -108,15 +105,6 @@ using OpSetVector = llvm::SmallSetVector; // Mapping for `_replication_info` attribute to ops of a cluster. 
using ClusterMap = llvm::SmallDenseMap; -auto* jit_compile_single_core_tpu_count = - tensorflow::monitoring::Counter<1>::New( - /* metric name */ - "/tensorflow/core/jit_compile_single_core_tpu_count", - /* metric description */ - "Tracks if single core tpu support goes through the first " - "phase of the MLIR bridge", - /* metric field */ "use_mlir_bridge"); - #define GEN_PASS_DEF_TPUCLUSTERFORMATIONPASS #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" @@ -943,7 +931,7 @@ void SetNoReplicationClusterAttrs(mlir::tf_device::ClusterOp cluster, LogicalResult FormClustersInBlock( Block* block, const mlir::TF::SideEffectAnalysis::Info& side_effect_analysis, - bool strict_clusters, bool& has_replication_in_module) { + bool strict_clusters) { MetadataMap metadata_map; LogicalResult result = CollectMetadata(block, &metadata_map); if (failed(result)) return result; @@ -956,8 +944,7 @@ LogicalResult FormClustersInBlock( if (!llvm::hasSingleElement(region)) return op.emitOpError("Expected single block region"); if (failed(FormClustersInBlock(®ion.front(), side_effect_analysis, - strict_clusters, - has_replication_in_module))) + strict_clusters))) return mlir::failure(); } } @@ -998,7 +985,6 @@ LogicalResult FormClustersInBlock( block, cluster_ops, results, cluster_successor_ops.getArrayRef()); if (!has_replication) { - has_replication_in_module = false; SetNoReplicationClusterAttrs(cluster, device_type, device); continue; } @@ -1034,12 +1020,12 @@ LogicalResult FormClustersInBlock( LogicalResult FormClustersInFunction( mlir::func::FuncOp func, const mlir::TF::SideEffectAnalysis::Info& side_effect_analysis, - bool strict_clusters, bool& has_replication_in_module) { + bool strict_clusters) { if (!llvm::hasSingleElement(func)) return func.emitOpError("Expecting a single block function"); if (failed(FormClustersInBlock(&func.front(), side_effect_analysis, - strict_clusters, has_replication_in_module))) + strict_clusters))) return mlir::failure(); 
// Remove TPUReplicatedInput and TPUReplicatedOutput nodes. @@ -1091,17 +1077,12 @@ void TPUClusterFormationPass::runOnOperation() { }); auto& side_effect_analysis = getAnalysis(); - bool has_replication_in_module = true; for (auto func : getOperation().getOps()) if (!func.isExternal() && failed(FormClustersInFunction( func, side_effect_analysis.GetAnalysisForFunc(func), - strict_clusters_, has_replication_in_module))) + strict_clusters_))) return signalPassFailure(); - - if (!has_replication_in_module) { - jit_compile_single_core_tpu_count->GetCell(kUseMlirBridge)->IncrementBy(1); - } } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation_test.cc deleted file mode 100644 index 640385f0156aae..00000000000000 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation_test.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h" -#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" -#include "tensorflow/core/lib/monitoring/cell_reader.h" -#include "tsl/platform/statusor.h" - -namespace tensorflow { -namespace tf2xla { -namespace internal { - -namespace { - -constexpr char kJitCompileSingleCoreTpuCount[] = - "/tensorflow/core/jit_compile_single_core_tpu_count"; -constexpr char kUseMlirBridge[] = "kUseMlirBridge"; -using mlir::mhlo::test::GetMlirModuleFromString; - -class TPUClusterFormationPassTest : public testing::Test { - protected: - void CreateModule(const char* module_string) { - TF_ASSERT_OK_AND_ASSIGN(module_, - GetMlirModuleFromString(module_string, &context_)); - bool strict_clusters = true; - pm_ = std::make_unique(&context_); - pm_->addPass(tensorflow::tf2xla::internal::CreateTPUClusterFormationPass( - strict_clusters)); - } - - mlir::LogicalResult Run() { return pm_->run(module_.get()); } - - private: - mlir::MLIRContext context_; - mlir::OwningOpRef module_; - std::unique_ptr pm_; -}; - -TEST_F(TPUClusterFormationPassTest, NonReplicatedTPU) { - monitoring::testing::CellReader feature_metric_reader( - kJitCompileSingleCoreTpuCount); - static constexpr char kMlirModuleStr[] = R"( - module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { - func.func @valid_compilation_cluster_no_replication() { - "tf.opA"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> () - 
func.return - } - })"; - CreateModule(kMlirModuleStr); - auto result = Run(); - EXPECT_TRUE(result.succeeded()); - EXPECT_EQ(feature_metric_reader.Delta(kUseMlirBridge), 1); -} - -TEST_F(TPUClusterFormationPassTest, ReplicatedTPU) { - monitoring::testing::CellReader feature_metric_reader( - kJitCompileSingleCoreTpuCount); - static constexpr char kMlirModuleStr[] = R"( - module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { - func.func @interleaved_clusters(%arg0 : tensor) -> (tensor, tensor) { - "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate_1", device = "device_1", num_replicas = 1, topology = "topology_1"} : () -> () - %0 = "tf.opA"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate_0", is_stateless = true} : (tensor) -> tensor - %1 = "tf.opB"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate_1", is_stateless = true} : (tensor) -> tensor - %2 = "tf.opC"(%0) {_xla_compile_device_type = "TPU", _replication_info = "replicate_0", is_stateless = true} : (tensor) -> tensor - %3 = "tf.opD"(%1) {_xla_compile_device_type = "TPU", _replication_info = "replicate_1", is_stateless = true} : (tensor) -> tensor - "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate_0", device = "device_0", num_replicas = 1, topology = "topology_0"} : () -> () - func.return %2, %3 : tensor, tensor - } - })"; - CreateModule(kMlirModuleStr); - auto result = Run(); - EXPECT_TRUE(result.succeeded()); - EXPECT_EQ(feature_metric_reader.Delta(kUseMlirBridge), 0); -} - -} // namespace -} // namespace internal -} // namespace tf2xla -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_outline_entry_functions.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_outline_entry_functions.cc deleted file mode 100644 index 78d700f1514b7c..00000000000000 --- 
a/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_outline_entry_functions.cc +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h" - -namespace tensorflow { -namespace tf2xla { -namespace internal { - -using mlir::ModuleOp; -using mlir::Operation; -using mlir::SymbolTable; - -#define GEN_PASS_DEF_XLAOUTLINEENTRYFUNCTIONSPASS -#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" - -inline constexpr char 
kOutlinedFuncSuffix[] = "_outlined"; - -// Outlines the body of an entry function with `_xla_compile_device_type` -// attribute and calls the outlined function with a -// `tf.StatefulPartitionedCall`. -struct XlaOutlineEntryFunctionsPass - : public impl::XlaOutlineEntryFunctionsPassBase< - XlaOutlineEntryFunctionsPass> { - void runOnOperation() override; -}; - -void RenameFunction(mlir::func::FuncOp func, const std::string &new_func_name, - SymbolTable &symtab) { - symtab.remove(func); - symtab.setSymbolName(func, new_func_name); - // Name conflicts are resolved automatically by SymbolTable class by attaching - // a unique counter value to the names. - symtab.insert(func); -} - -// Propagate compilation markers from the source to the destination. -void PropagateCompilationMarkers(Operation *src, Operation *dest) { - mlir::TF::CopyUnderscoredAttributes(src, dest); - if (src->hasAttr(mlir::TF::kAllowSoftPlacementAttr)) { - dest->setAttr(mlir::TF::kAllowSoftPlacementAttr, - src->getAttr(mlir::TF::kAllowSoftPlacementAttr)); - } -} - -mlir::func::FuncOp CreateWrapperFunction(mlir::func::FuncOp func, - const std::string &caller_name, - const std::string &callee_name) { - mlir::OpBuilder builder(func); - mlir::OpBuilder::InsertionGuard guard(builder); - mlir::FunctionType func_type = func.getFunctionType(); - mlir::Location loc = func.getLoc(); - auto wrapper_func = mlir::func::FuncOp::create(loc, caller_name, func_type); - mlir::Block *block = builder.createBlock(&wrapper_func.getBody()); - block->addArguments( - wrapper_func.getArgumentTypes(), - llvm::SmallVector(wrapper_func.getNumArguments(), loc)); - auto pcall_op = builder.create( - loc, func_type.getResults(), wrapper_func.getArguments(), - mlir::SymbolRefAttr::get(builder.getContext(), callee_name), - builder.getStringAttr(""), builder.getStringAttr(""), - builder.getStringAttr("")); - builder.create(loc, pcall_op.getResults()); - PropagateCompilationMarkers(func, pcall_op); - // Mark the original function 
private so it can be inlined. - func.setVisibility(mlir::func::FuncOp::Visibility::Private); - return wrapper_func; -} - -void ReplaceEntryFunction(mlir::func::FuncOp original_func, - mlir::func::FuncOp new_func) { - auto move_attr = [&](auto attr, Operation *src, Operation *dest) { - if (src->hasAttr(attr)) { - dest->setAttr(attr, src->getAttr(attr)); - src->removeAttr(attr); - } - }; - - for (const auto &attr : mlir::GetEntryFunctionAttributeNames()) { - move_attr(attr, original_func, new_func); - } - mlir::TF::CopyDeviceAndUnderscoredAttributes(original_func, new_func); -} - -mlir::func::FuncOp RewriteEntryFunctionWithCompilationMarkers( - mlir::func::FuncOp entry_func, SymbolTable &symtab) { - const std::string entry_func_name = entry_func.getSymName().str(), - outlined_entry_func_name = - entry_func_name + kOutlinedFuncSuffix; - RenameFunction(entry_func, outlined_entry_func_name, symtab); - auto new_entry_func = CreateWrapperFunction(entry_func, entry_func_name, - outlined_entry_func_name); - ReplaceEntryFunction(entry_func, new_entry_func); - symtab.insert(new_entry_func); - return new_entry_func; -} - -void XlaOutlineEntryFunctionsPass::runOnOperation() { - ModuleOp module = getOperation(); - SymbolTable symtab(module); - - llvm::SmallVector entry_funcs = GetEntryFunctions(module); - - for (auto &entry_func : entry_funcs) { - if (entry_func->hasAttr(mlir::TF::kCompileDeviceTypeAttr)) { - RewriteEntryFunctionWithCompilationMarkers(entry_func, symtab); - } - } -} - -std::unique_ptr> -CreateXlaOutlineEntryFunctionsPass() { - return std::make_unique(); -} - -} // namespace internal -} // namespace tf2xla -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/test_matchers_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/test_matchers_test.cc index 5ae2fcecedc3fc..696776f75b021c 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/test_matchers_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/test_matchers_test.cc @@ -43,14 
+43,14 @@ template tsl::StatusOr success(T t) { return t; } -tsl::StatusOr success() { return kArbitraryIntResult; } +absl::StatusOr success() { return kArbitraryIntResult; } template tsl::StatusOr filtered(T t) { return tsl::StatusOr(tensorflow::CompileToHloGraphAnalysisFailedError()); } -tsl::StatusOr filtered() { return filtered(kArbitraryIntResult); } -tsl::StatusOr failed() { - return tsl::StatusOr(absl::InternalError("fail")); +absl::StatusOr filtered() { return filtered(kArbitraryIntResult); } +absl::StatusOr failed() { + return absl::StatusOr(absl::InternalError("fail")); } TEST(TestUtil, MatchesOk) { ASSERT_THAT(success(), IsOkOrFiltered()); } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/adjust-layout.mlir b/tensorflow/compiler/mlir/tf2xla/tests/adjust-layout.mlir index dde3c4213e6146..70b42f32392ca1 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/adjust-layout.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/adjust-layout.mlir @@ -4,7 +4,9 @@ func.func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) // CHECK: [[TOKEN:%.*]] = mhlo.create_token : !mhlo.token %0 = "mhlo.create_token"() : () -> !mhlo.token - // CHECK: [[INFEED:%.*]]:3 = "mhlo.infeed"([[TOKEN]]) {infeed_config = "", layout = [{{\[1, 3, 2, 0], \[1, 2, 0]}}]} : (!mhlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !mhlo.token) + // CHECK: [[INFEED:%.*]]:3 = "mhlo.infeed"([[TOKEN]]) <{ + // CHECK-SAME{LITERAL}: infeed_config = "", layout = [[1, 3, 2, 0], [1, 2, 0]] + // CHECK-SAME: }> : (!mhlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !mhlo.token) %1:3 = "mhlo.infeed"(%0) {infeed_config = ""} : (!mhlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !mhlo.token) // CHECK: return [[INFEED]]#0, [[INFEED]]#1 diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir index ebefc1ca1ab140..f62e9a140e83d9 100644 --- 
a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-BatchMatMulV2.mlir @@ -15,7 +15,7 @@ func.func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32 // CHECK: [[BCASTHEAD:%.*]] = shape.broadcast [[LHSHEAD]], [[RHSHEAD]] // CHECK: [[LHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[LHSTAIL]] // CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] -// CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32> +// CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32> // CHECK: [[RHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[RHSTAIL]] // CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHS]]) // CHECK: return [[RESULT]] : tensor<3x4x4xf32> @@ -27,8 +27,8 @@ func.func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32 func.func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_lhs_batch -// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} -// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) { +// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> +// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{ // CHECK-SAME: lhs_batching_dimensions = [0] // CHECK-SAME: rhs_batching_dimensions = [0] // CHECK-SAME: lhs_contracting_dimensions = [2] @@ -39,8 +39,8 @@ func.func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf func.func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> 
tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_rhs_batch -// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} -// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) { +// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> +// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{ // CHECK-SAME: lhs_batching_dimensions = [0] // CHECK-SAME: rhs_batching_dimensions = [0] // CHECK-SAME: lhs_contracting_dimensions = [2] @@ -51,7 +51,7 @@ func.func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf func.func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: func @batchmatmulv2_dynamic -// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) { +// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{ // CHECK-SAME: lhs_batching_dimensions = [0] // CHECK-SAME: rhs_batching_dimensions = [0] // CHECK-SAME: lhs_contracting_dimensions = [2] @@ -62,7 +62,7 @@ func.func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor, %arg1: tensor<4x2xf32>) -> tensor<5x4xf32> { // CHECK-LABEL: func @batchmatmulv2_adj_real -// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) { +// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) <{ // CHECK-NOT: lhs_batching_dimensions // CHECK-NOT: rhs_batching_dimensions // CHECK-SAME: lhs_contracting_dimensions = [0] diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir index 01dc4701923675..da64452a3039f8 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir @@ -23,7 +23,7 @@ func.func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { // patterns unambiguous and more interesting (once broadcastable trait is // fixed upstream). 
func.func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1 %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0: tensor<1x2xi32> @@ -33,7 +33,7 @@ func.func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor // TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream // broadcastable bug is fixed (helps make the CHECK matching unambiguous) func.func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}> // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1 %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> func.return %0: tensor<4x4x4x4xi32> @@ -48,8 +48,8 @@ func.func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor, tensor<2xindex> -> tensor<2xindex> - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> + // CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) <{broadcast_dimensions = dense<[0, 1]> : 
tensor<2xi64>}> // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor // CHECK-NEXT: shape.assuming_yield %[[RESULT]] %0 = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -149,7 +149,7 @@ func.func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> t // CHECK-LABEL: func @broadcast_shift_right_unsigned func.func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { - // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xui8>) -> tensor<2x4xui8> + // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<4xui8>) -> tensor<2x4xui8> // CHECK: mhlo.shift_right_logical %[[BROADCAST]], %arg1 : tensor<2x4xui8> %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> func.return %0 : tensor<2x4xui8> @@ -248,8 +248,8 @@ func.func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor, tensor -> tensor // NOT-CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor.cast %[[RESULT_SHAPE]] : tensor to tensor<1xindex> - // NOT-CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // NOT-CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // NOT-CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> + // NOT-CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> // NOT-CHECK-NEXT: %[[RESULT:.+]] = mhlo.compare EQ, %[[LHS_BCAST]], %[[RHS_BCAST]] // NOT-CHECK-NEXT: shape.assuming_yield %[[RESULT]] %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor @@ -258,7 
+258,7 @@ func.func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> // CHECK-NEXT: mhlo.compare EQ, %[[LHS_BCAST]], %arg1 %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0: tensor<1x2xi1> @@ -329,7 +329,7 @@ func.func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { // CHECK-LABEL: func @broadcast_greater func.func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> // CHECK-NEXT: mhlo.compare GT, %[[LHS_BCAST]], %arg1 %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0: tensor<1x2xi1> @@ -344,8 +344,8 @@ func.func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor // CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0 // CHECK-DAG: %[[RHS_SHAPE1:.+]] = shape.shape_of %arg1 // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[LHS_SHAPE1]], %[[RHS_SHAPE1]] : tensor<1xindex>, tensor<1xindex> -> tensor<1xindex> - // CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> + // CHECK-DAG: %[[RHS_BCAST:.+]] = 
"mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> // CHECK-NEXT: mhlo.compare GT, %[[LHS_BCAST]], %[[RHS_BCAST]] %0 = "tf.Greater"(%arg0, %arg1) : (tensor, tensor) -> tensor func.return %0: tensor diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir index fa0bc94a980eb0..9c5653c61b9703 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir @@ -8,9 +8,9 @@ func.func @all_reduce_cross_replica(%input: tensor) -> tensor { %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32> // CHECK: "mhlo.all_reduce" - // CHECK: mhlo.add // CHECK{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> // CHECK-NOT: channel_handle + // CHECK: mhlo.add %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplica"} : (tensor, tensor<2x1xi32>) -> tensor func.return %0 : tensor } @@ -24,16 +24,16 @@ func.func @all_reduce_cross_replica(%input: tensor) -> tensor { func.func @all_reduce_cross_replica_and_partition(%input: tensor) -> tensor { %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32> // CHECK: "mhlo.all_reduce" + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT: channel_handle = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplicaAndPartition"} : (tensor, tensor<2x1xi32>) -> tensor // CHECK: "mhlo.all_reduce" + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : 
tensor<2x1xi64> // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT: channel_handle = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> %1 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplicaAndPartition"} : (tensor, tensor<2x1xi32>) -> tensor %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> tensor func.return %2 : tensor @@ -110,16 +110,16 @@ func.func @collective_reduce_v2(%input: tensor) -> tensor { %group_size = "tf.Const"() { value = dense<2> : tensor } : () -> tensor %instance_key = "tf.Const"() { value = dense<3> : tensor } : () -> tensor // CHECK: "mhlo.all_reduce" + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT: channel_handle = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor, tensor, tensor, tensor) -> tensor // CHECK: "mhlo.all_reduce" + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT: channel_handle = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> %1 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor, tensor, tensor, tensor) -> tensor %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> tensor func.return %2 : tensor @@ -133,10 +133,10 @@ func.func @collective_reduce_v2_add_id(%input: tensor) -> tensor { %group_size = "tf.Const"() { value = dense<2> : tensor } : () -> tensor %instance_key = "tf.Const"() { value = dense<3> : tensor } : () -> tensor // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce" + // CHECK-SAME{LITERAL}: 
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - // CHECK-NEXT: return %[[REDUCE]] + // CHECK: return %[[REDUCE]] %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -147,10 +147,10 @@ func.func @collective_reduce_v2_max_id(%input: tensor) -> tensor { %group_size = "tf.Const"() { value = dense<2> : tensor } : () -> tensor %instance_key = "tf.Const"() { value = dense<3> : tensor } : () -> tensor // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce" + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.maximum // CHECK: mhlo.return - // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - // CHECK-NEXT: return %[[REDUCE]] + // CHECK: return %[[REDUCE]] %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Max", final_op = "Id"} : (tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -161,10 +161,10 @@ func.func @collective_reduce_v2_min_id(%input: tensor) -> tensor { %group_size = "tf.Const"() { value = dense<2> : tensor } : () -> tensor %instance_key = "tf.Const"() { value = dense<3> : tensor } : () -> tensor // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce" + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.minimum // CHECK: mhlo.return - // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - // CHECK-NEXT: return %[[REDUCE]] + // CHECK: return %[[REDUCE]] %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Min", final_op = "Id"} : (tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -175,10 +175,10 @@ func.func @collective_reduce_v2_mul_id(%input: tensor) -> tensor { %group_size = "tf.Const"() { 
value = dense<2> : tensor } : () -> tensor %instance_key = "tf.Const"() { value = dense<3> : tensor } : () -> tensor // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce" + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.mul // CHECK: mhlo.return - // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - // CHECK-NEXT: return %[[REDUCE]] + // CHECK: return %[[REDUCE]] %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Mul", final_op = "Id"} : (tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -190,10 +190,10 @@ func.func @collective_reduce_v2_add_div(%input: tensor) -> tensor { %instance_key = "tf.Const"() { value = dense<3> : tensor } : () -> tensor // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00> // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce" + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]] + // CHECK: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]] // CHECK-NEXT: return %[[RESULT]] %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Div"} : (tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor @@ -206,10 +206,10 @@ func.func @collective_reduce_v2_max_div(%input: tensor) -> tensor { %instance_key = "tf.Const"() { value = dense<3> : tensor } : () -> tensor // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00> // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce" + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.maximum // CHECK: mhlo.return - // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]] + // CHECK: 
%[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]] // CHECK-NEXT: return %[[RESULT]] %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Max", final_op = "Div"} : (tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor @@ -222,10 +222,10 @@ func.func @collective_reduce_v2_min_div(%input: tensor) -> tensor { %instance_key = "tf.Const"() { value = dense<3> : tensor } : () -> tensor // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00> // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce" + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.minimum // CHECK: mhlo.return - // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]] + // CHECK: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]] // CHECK-NEXT: return %[[RESULT]] %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Min", final_op = "Div"} : (tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor @@ -238,10 +238,10 @@ func.func @collective_reduce_v2_mul_div(%input: tensor) -> tensor { %instance_key = "tf.Const"() { value = dense<3> : tensor } : () -> tensor // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00> // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce" + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK: mhlo.mul // CHECK: mhlo.return - // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]] + // CHECK: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]] // CHECK-NEXT: return %[[RESULT]] %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Mul", final_op = "Div"} : (tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor diff --git 
a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-prefer-tf2xla.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-prefer-tf2xla.mlir index c1bff70e2e4ff6..49a26ef844623f 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-prefer-tf2xla.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-prefer-tf2xla.mlir @@ -69,22 +69,22 @@ func.func @random_uniform_without_seeds(%arg0: tensor<4xi32>) -> tensor<32x12x12 func.func @random_uniform_with_seeds(%arg0: tensor<4xi32>) -> tensor<32x12x12x64xf32> { // CHECK: %0 = mhlo.constant dense<[32, 12, 12, 64]> : tensor<4xi32> // CHECK-NEXT: %1 = mhlo.constant dense<[32, 12, 12, 64]> : tensor<4xi32> - // CHECK-NEXT: %2 = "mhlo.slice"(%1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1xi32> + // CHECK-NEXT: %2 = "mhlo.slice"(%1) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<1xi32> // CHECK-NEXT: %3 = mhlo.reshape %2 : (tensor<1xi32>) -> tensor // CHECK-NEXT: %4 = mhlo.convert %3 : tensor - // CHECK-NEXT: %5 = "mhlo.slice"(%1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1xi32> + // CHECK-NEXT: %5 = "mhlo.slice"(%1) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<1xi32> // CHECK-NEXT: %6 = mhlo.reshape %5 : (tensor<1xi32>) -> tensor // CHECK-NEXT: %7 = mhlo.convert %6 : tensor - // CHECK-NEXT: %8 = "mhlo.slice"(%1) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1xi32> + // CHECK-NEXT: %8 = "mhlo.slice"(%1) <{limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = 
dense<1> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<1xi32> // CHECK-NEXT: %9 = mhlo.reshape %8 : (tensor<1xi32>) -> tensor // CHECK-NEXT: %10 = mhlo.convert %9 : tensor - // CHECK-NEXT: %11 = "mhlo.slice"(%1) {limit_indices = dense<4> : tensor<1xi64>, start_indices = dense<3> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1xi32> + // CHECK-NEXT: %11 = "mhlo.slice"(%1) <{limit_indices = dense<4> : tensor<1xi64>, start_indices = dense<3> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<1xi32> // CHECK-NEXT: %12 = mhlo.reshape %11 : (tensor<1xi32>) -> tensor // CHECK-NEXT: %13 = mhlo.convert %12 : tensor // CHECK-NEXT: %14 = mhlo.constant dense<0.000000e+00> : tensor // CHECK-NEXT: %15 = mhlo.constant dense<1.000000e+00> : tensor // CHECK-NEXT: %16 = mhlo.constant dense<[32, 12, 12, 64]> : tensor<4xi64> - // CHECK-NEXT: %17 = "mhlo.rng"(%14, %15, %16) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> + // CHECK-NEXT: %17 = "mhlo.rng"(%14, %15, %16) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> %cst = "tf.Const"() {value = dense<[32, 12, 12, 64]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = "tf.RandomUniform"(%cst) {seed = 87654321 : i64, seed2 = 0 : i64} : (tensor<4xi32>) -> tensor<32x12x12x64xf32> // CHECK: return %17 : tensor<32x12x12x64xf32> @@ -103,7 +103,7 @@ func.func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + 
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> @@ -114,26 +114,26 @@ func.func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> func.func @slice_variable_start_negsize(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi32>) -> tensor<1x4xi32> { // CHECK: %0 = mhlo.constant dense<[1, -1]> : tensor<2xi32> // CHECK-NEXT: %1 = mhlo.constant dense<[1, -1]> : tensor<2xi32> - // CHECK-NEXT: %2 = "mhlo.slice"(%1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %2 = "mhlo.slice"(%1) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: %3 = mhlo.reshape %2 : (tensor<1xi32>) -> tensor // CHECK-NEXT: %4 = mhlo.constant dense<[1, -1]> : tensor<2xi32> - // CHECK-NEXT: %5 = "mhlo.slice"(%4) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %5 = "mhlo.slice"(%4) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: %6 = mhlo.reshape %5 : (tensor<1xi32>) -> tensor // CHECK-NEXT: %7 = mhlo.constant dense<3> : tensor - // CHECK-NEXT: %8 = "mhlo.set_dimension_size"(%arg0, %7) {dimension = 0 : i64} : (tensor<3x4xi32>, tensor) -> tensor<3x4xi32> + // CHECK-NEXT: %8 = "mhlo.set_dimension_size"(%arg0, %7) 
<{dimension = 0 : i64}> : (tensor<3x4xi32>, tensor) -> tensor<3x4xi32> // CHECK-NEXT: %9 = mhlo.constant dense<4> : tensor - // CHECK-NEXT: %10 = "mhlo.set_dimension_size"(%8, %9) {dimension = 1 : i64} : (tensor<3x4xi32>, tensor) -> tensor<3x4xi32> + // CHECK-NEXT: %10 = "mhlo.set_dimension_size"(%8, %9) <{dimension = 1 : i64}> : (tensor<3x4xi32>, tensor) -> tensor<3x4xi32> // CHECK-NEXT: %11 = mhlo.constant dense<0> : tensor - // CHECK-NEXT: %12 = "mhlo.pad"(%10, %11) {edge_padding_high = dense<[3, 4]> : tensor<2xi64>, edge_padding_low = dense<0> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<3x4xi32>, tensor) -> tensor<6x8xi32> - // CHECK-NEXT: %13 = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %12 = "mhlo.pad"(%10, %11) <{edge_padding_high = dense<[3, 4]> : tensor<2xi64>, edge_padding_low = dense<0> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor) -> tensor<6x8xi32> + // CHECK-NEXT: %13 = "mhlo.slice"(%arg1) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: %14 = mhlo.reshape %13 : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %15 = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %15 = "mhlo.slice"(%arg1) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: %16 = mhlo.reshape %15 : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %17 = "mhlo.dynamic_slice"(%12, %14, %16) {slice_sizes = dense<[3, 4]> : tensor<2xi64>} : (tensor<6x8xi32>, tensor, tensor) -> tensor<3x4xi32> - // 
CHECK-NEXT: %18 = "mhlo.slice"(%17) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + // CHECK-NEXT: %17 = "mhlo.dynamic_slice"(%12, %14, %16) <{slice_sizes = dense<[3, 4]> : tensor<2xi64>}> : (tensor<6x8xi32>, tensor, tensor) -> tensor<3x4xi32> + // CHECK-NEXT: %18 = "mhlo.slice"(%17) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<3x4xi32>) -> tensor<1x4xi32> // CHECK-NEXT: %19 = mhlo.constant dense<4> : tensor // CHECK-NEXT: %20 = mhlo.subtract %19, %16 : tensor - // CHECK-NEXT: %21 = "mhlo.set_dimension_size"(%18, %20) {dimension = 1 : i64} : (tensor<1x4xi32>, tensor) -> tensor<1x?xi32, #mhlo.type_extensions> + // CHECK-NEXT: %21 = "mhlo.set_dimension_size"(%18, %20) <{dimension = 1 : i64}> : (tensor<1x4xi32>, tensor) -> tensor<1x?xi32, #mhlo.type_extensions> // CHECK-NEXT: %cast = tensor.cast %21 : tensor<1x?xi32, #mhlo.type_extensions> to tensor<1x4xi32> // CHECK-NEXT: return %cast : tensor<1x4xi32> %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi32>} : () -> (tensor<2xi32>) @@ -178,7 +178,7 @@ func.func @fused_conv2d(%input: tensor<1x300x300x40xi8>, // CHECK-NEXT: %[[v1:.*]] = mhlo.constant dense<2.000000e+00> : tensor // CHECK-NEXT: %[[v2:.*]] = mhlo.constant dense<2.000000e+00> : tensor // CHECK-NEXT: %[[v3:.*]] = mhlo.constant dense<-1.280000e+02> : tensor - // CHECK-NEXT: %[[v4:.*]] = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v4:.*]] = "mhlo.broadcast_in_dim"(%3) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<1x300x300x40xf32> // CHECK-NEXT: %[[v5:.*]] = mhlo.convert %arg0 : (tensor<1x300x300x40xi8>) -> tensor<1x300x300x40xf32> // CHECK-NEXT: %[[v6:.*]] = mhlo.convert %arg1 : (tensor<3x3x40x40xi8>) -> tensor<3x3x40x40xf32> 
// CHECK: %[[v7:.*]] = mhlo.convolution(%[[v5]], %[[v6]]) @@ -188,16 +188,16 @@ func.func @fused_conv2d(%input: tensor<1x300x300x40xi8>, // CHECK-SAME: feature_group_count = 1 // CHECK-NEXT: %[[v8:.*]] = mhlo.convert %7 : tensor<1x300x300x40xf32> // CHECK-NEXT: %[[v9:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK-NEXT: %[[v10:.*]] = "mhlo.broadcast_in_dim"(%9) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v10:.*]] = "mhlo.broadcast_in_dim"(%9) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<1x300x300x40xf32> // CHECK-NEXT: %11 = mhlo.multiply %8, %10 : tensor<1x300x300x40xf32> // CHECK-NEXT: %12 = mhlo.convert %arg2 : tensor<40xf32> - // CHECK-NEXT: %13 = "mhlo.broadcast_in_dim"(%12) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<40xf32>) -> tensor<1x300x300x40xf32> + // CHECK-NEXT: %13 = "mhlo.broadcast_in_dim"(%12) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<40xf32>) -> tensor<1x300x300x40xf32> // CHECK-NEXT: %14 = mhlo.add %11, %13 : tensor<1x300x300x40xf32> // CHECK-NEXT: %15 = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %16 = "mhlo.broadcast_in_dim"(%15) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x300x300x40xf32> + // CHECK-NEXT: %16 = "mhlo.broadcast_in_dim"(%15) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<1x300x300x40xf32> // CHECK-NEXT: %17 = mhlo.maximum %14, %16 : tensor<1x300x300x40xf32> // CHECK-NEXT: %18 = mhlo.constant dense<1.270000e+02> : tensor - // CHECK-NEXT: %19 = "mhlo.broadcast_in_dim"(%18) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x300x300x40xf32> + // CHECK-NEXT: %19 = "mhlo.broadcast_in_dim"(%18) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<1x300x300x40xf32> // CHECK-NEXT: %20 = mhlo.clamp %4, %17, %19 : tensor<1x300x300x40xf32> // CHECK-NEXT: %21 = mhlo.round_nearest_even %20 : 
tensor<1x300x300x40xf32> // CHECK-NEXT: %22 = mhlo.convert %21 : (tensor<1x300x300x40xf32>) -> tensor<1x300x300x40xi8> diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir index 6004400ffe8802..328a00ce59bbec 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir @@ -256,7 +256,7 @@ func.func @uniform_quantized_dot(%input: tensor) -> tensor : tensor } : () -> tensor // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() - // CHECK-SAME{LITERAL}: {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform> + // CHECK-SAME{LITERAL}: <{value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi8>}> : () -> tensor<2x2x!quant.uniform> // CHECK-DAG: %[[LHS:.*]] = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> // CHECK-DAG: %[[CONVERT_1:.*]] = mhlo.bitcast_convert %[[LHS]] : (tensor>) -> tensor // CHECK-DAG: %[[CONVERT_2:.*]] = mhlo.bitcast_convert %[[CONVERT_1]] : (tensor) -> tensor> @@ -306,7 +306,7 @@ func.func @uniform_quantized_convolution(%input: tensor<1x6x6x3xf32>) -> tensor< %output_zps = "tf.Const"() { value = dense<5> : tensor } : () -> tensor // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() - // CHECK-SAME{LITERAL}: {value = dense<127> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform> + // CHECK-SAME{LITERAL}: <{value = dense<127> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform> // CHECK-DAG: %[[LHS:.*]] = mhlo.uniform_quantize %arg0 : (tensor<1x6x6x3xf32>) -> tensor<1x6x6x3x!quant.uniform> // CHECK-DAG: %[[CONVERT_1:.*]] = mhlo.bitcast_convert %[[LHS]] : (tensor<1x6x6x3x!quant.uniform>) -> tensor<1x6x6x3xi8> // CHECK-DAG: %[[CONVERT_2:.*]] = mhlo.bitcast_convert %[[CONVERT_1]] : (tensor<1x6x6x3xi8>) -> tensor<1x6x6x3x!quant.uniform> @@ -367,7 +367,7 @@ func.func @uniform_quantized_add(%arg0: tensor<3x2x!tf_type.qint32>) -> tensor<3 %output_zps = "tf.Const"() { value = 
dense<4> : tensor } : () -> tensor // CHECK-DAG: %[[LHS:.*]] = mhlo.bitcast_convert %arg0 : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> - // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> + // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() <{value = dense<127> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> // CHECK: %[[RES:.*]] = chlo.broadcast_add %[[LHS]], %[[RHS]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> @@ -407,7 +407,7 @@ func.func @uniform_quantized_clip_by_value(%input: tensor<3x2xf32>) -> tensor<3x %zps = "tf.Const"() { value = dense<4> : tensor<2xi32> } : () -> tensor<2xi32> // tensor_proto that points to dense<127> of type !tf_type.qint32. - // CHECK-DAG: %[[MIN_MAX:.*]] = mhlo.constant() {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> + // CHECK-DAG: %[[MIN_MAX:.*]] = mhlo.constant() <{value = dense<127> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> %min = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> %max = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir index aabc9d471f8385..f1fb2fec85722c 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir @@ -123,12 +123,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: binary_op_broadcast func.func @binary_op_broadcast(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> { - // CHECK: %[[BROADCAST0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions 
= dense<[1, 2]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4x1xf32> + // CHECK: %[[BROADCAST0:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<4x1xf32>) -> tensor<4x4x1xf32> // CHECK: %[[RESHAPE0:.*]] = mhlo.reshape %[[BROADCAST0]] : (tensor<4x4x1xf32>) -> tensor<4x4xf32> - // CHECK: %[[UPDATED_ARG0:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE0]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> + // CHECK: %[[UPDATED_ARG0:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE0]]) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<4x4xf32>) -> tensor<4x4x4xf32> // CHECK: %[[RESHAPE1:.*]] = mhlo.reshape %arg1 : (tensor<4x1x4xf32>) -> tensor<4x4xf32> - // CHECK: %[[UPDATED_ARG1:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE1]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> + // CHECK: %[[UPDATED_ARG1:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE1]]) <{broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>}> : (tensor<4x4xf32>) -> tensor<4x4x4xf32> // CHECK: %[[RESULT:.*]] = mhlo.atan2 %[[UPDATED_ARG0]], %[[UPDATED_ARG1]] : tensor<4x4x4xf32> // CHECK: return %[[RESULT]] : tensor<4x4x4xf32> @@ -228,18 +228,18 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor) func.func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor) -> tensor<3x3xf32> { - // CHECK: %[[DEFAULT:.*]] = "mhlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<3x3xf32> + // CHECK: %[[DEFAULT:.*]] = "mhlo.broadcast_in_dim"(%[[ARG2]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<3x3xf32> - // CHECK: %[[RESULT:.*]] = "mhlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ({ - // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): - 
// CHECK: mhlo.return %[[ARG4]] : tensor - // CHECK: }) + // CHECK: %[[RESULT:.*]] = "mhlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers // CHECK-SAME: inserted_window_dims = [0, 1] // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] // CHECK-SAME: index_vector_dim = 1 // CHECK-SAME: unique_indices = false + // CHECK-NEXT: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): + // CHECK: mhlo.return %[[ARG4]] : tensor + // CHECK: }) // CHECK-SAME: (tensor<3x3xf32>, tensor<3x2xi32>, tensor<3xf32>) -> tensor<3x3xf32> // return %[[RESULT]] : tensor<3x3xf32> @@ -332,7 +332,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.func @set_dynamic_dimension_size(%input: tensor<4xf32>, %size: tensor) -> tensor { %dimension = "tf.Const"() { value = dense<0> : tensor } : () -> tensor // CHECK: mhlo.set_dimension_size - // CHECK-SAME: {dimension = 0 : i64} : (tensor<4xf32>, tensor) -> tensor> + // CHECK-SAME: <{dimension = 0 : i64}> : (tensor<4xf32>, tensor) -> tensor> %0 = "tf.XlaSetDynamicDimensionSize"(%input, %dimension, %size) : (tensor<4xf32>, tensor, tensor) -> tensor func.return %0 : tensor } @@ -469,7 +469,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.func @bounds_propagation(%input: tensor<4xf32>, %size: tensor) -> tensor { %dimension = "tf.Const"() { value = dense<0> : tensor } : () -> tensor // CHECK: %[[BOUNDED:.*]] = "mhlo.set_dimension_size" - // CHECK-SAME: {dimension = 0 : i64} : (tensor<4xf32>, tensor) -> tensor> + // CHECK-SAME: <{dimension = 0 : i64}> : (tensor<4xf32>, tensor) -> tensor> %0 = "tf.XlaSetDynamicDimensionSize"(%input, %dimension, %size) : (tensor<4xf32>, tensor, tensor) -> tensor %axis = "tf.Const"() { value = dense<0> : tensor<1xi32> } : () -> tensor<1xi32> @@ -487,7 +487,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.func 
@bounds_propagation_skip_symbol_ref_ops(%input: tensor<4xf32>, %size: tensor) -> tensor { %dimension = "tf.Const"() { value = dense<0> : tensor } : () -> tensor // CHECK: %[[BOUNDED:.*]] = "mhlo.set_dimension_size" - // CHECK-SAME: {dimension = 0 : i64} : (tensor<4xf32>, tensor) -> tensor> + // CHECK-SAME: <{dimension = 0 : i64}> : (tensor<4xf32>, tensor) -> tensor> %0 = "tf.XlaSetDynamicDimensionSize"(%input, %dimension, %size) : (tensor<4xf32>, tensor, tensor) -> tensor // CHECK: %[[ORIGINAL:.*]] = tensor.cast %[[BOUNDED]] : tensor> to tensor @@ -538,14 +538,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: fusedBatchNormV3_noTraining func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) func.return %0#0 : tensor<8x8x8x8xf32> } // CHECK-LABEL: fusedBatchNormV3_training func.func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: %[[OUT:.*]], 
%[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) func.return %0#0 : tensor<8x8x8x8xf32> } @@ -555,7 +555,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> // CHECK: %[[scr1:.*]] = mhlo.rsqrt - // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> // CHECK: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> @@ -566,7 +566,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> // CHECK: %[[mul2:.*]] = 
mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> - // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> // CHECK: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> @@ -589,7 +589,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) { // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32> // CHECK: return %[[x_backprop]] // CHECK-SAME: tensor<8x8x8x8xf32> @@ -602,7 +602,9 @@ module attributes {tf.versions = 
{bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32> func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ({ + // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) <{ + // CHECK-SAME: padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + // CHECK-SAME }> ({ // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]] : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor @@ -610,7 +612,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor - // CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + // CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32> %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { data_format = "NHWC", @@ -661,7 +663,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: func @concat_v2 func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) 
-> tensor<6x3xf32> { - // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + // CHECK: "mhlo.concatenate"({{.*}}) <{dimension = 0 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> func.return %1 : tensor<6x3xf32> diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index ed288a27fa7383..bb9ca266fc7abc 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -18,7 +18,7 @@ // CHECK-LABEL: fusedBatchNormV2_noTraining func.func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) func.return %0#0 : tensor<8x8x8x8xf32> } @@ -27,7 +27,7 @@ func.func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor // CHECK-LABEL: fusedBatchNormV2_training func.func 
@fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK: mhlo.constant // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor) -> tensor<8xf32> @@ -38,7 +38,7 @@ func.func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8 // CHECK-LABEL: fusedBatchNormV3_noTraining func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> %0:6 = 
"tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) func.return %0#0 : tensor<8x8x8x8xf32> } @@ -49,7 +49,7 @@ func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor // CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>) func.func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) { // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK: [[Y:%.*]] = "mhlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + // CHECK: [[Y:%.*]] = "mhlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) // CHECK: [[Y_CONVERT:%.*]] = mhlo.convert [[Y]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> // CHECK: [[DUMMY:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<0xf32> @@ -62,7 +62,7 @@ func.func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16 // CHECK-LABEL: 
fusedBatchNormV3_training func.func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK: mhlo.constant // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor) -> tensor<8xf32> @@ -73,7 +73,7 @@ func.func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8 // CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance func.func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> { - // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 
: i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK: return %[[VAR]] func.return %0#4 : tensor<8xf32> @@ -83,7 +83,7 @@ func.func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, % // CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor func.func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { - // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694> // CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[VAR]], 
%[[FACTOR]] @@ -117,7 +117,7 @@ func.func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, // CHECK-LABEL: fusedBatchNormV3_NCHW func.func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) func.return %0#0 : tensor<8x8x8x8xf32> } @@ -135,7 +135,7 @@ func.func @fusedBatchNormV3_NDHWC(%arg0: tensor<8x8x8x8x8xf32>, %arg1: tensor<8x // CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported func.func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { - // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor, tensor, tensor, tensor, tensor) -> tensor %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 
0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) func.return %0#0 : tensor } @@ -169,7 +169,7 @@ func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> - // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> @@ -179,7 +179,7 @@ func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> - // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> @@ -202,7 +202,7 @@ func.func @fusedBatchNormGrad_noTraining(%arg0: 
tensor<8x8x8x8xf32>, %arg1: tens func.func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> @@ -221,7 +221,7 @@ func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> - // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> // CHECK-NEXT: 
%[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> @@ -231,7 +231,7 @@ func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> - // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> @@ -255,7 +255,7 @@ func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te func.func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, 
tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> @@ -283,7 +283,7 @@ func.func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8 func.func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> @@ -302,7 +302,7 @@ func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> - // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = 
dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> @@ -312,7 +312,7 @@ func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> - // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> @@ -336,7 +336,7 @@ func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te func.func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) { // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, 
feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32> // CHECK: return %[[x_backprop]] // CHECK-SAME: tensor<8x8x8x8xf32> @@ -365,7 +365,7 @@ func.func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8 func.func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> // CHECK-NEXT: 
return %[[x_backprop]] : tensor<8x8x8x8xbf16> @@ -384,7 +384,7 @@ func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> - // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64> @@ -394,7 +394,7 @@ func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> - // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> @@ -416,7 +416,7 @@ func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg // CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW func.func @fusedBatchNormGradV3_Training_NCHW(%arg0: 
tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: %{{.*}} = "mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: %{{.*}} = "mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) func.return %0#0 : tensor<8x8x8x8xf32> } @@ -524,8 +524,8 @@ func.func @clip_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : te // CHECK-LABEL: @clip_static_broadcast func.func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<5xf32> { // CHECK-DAG: [[SHPIDX:%.+]] = mhlo.constant dense<5> - // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} - // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> // 
CHECK-DAG: [[CLAMP:%.+]] = mhlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor, tensor) -> tensor<5xf32> @@ -538,8 +538,8 @@ func.func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor, %ar func.func @clip_dynamic_broadcast(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { // CHECK: [[SHP:%.+]] = shape.shape_of %arg0 // CHECK: [[SHPIDX:%.+]] = arith.index_cast [[SHP]] : tensor<1xindex> to tensor<1xi32> - // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} - // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> // CHECK-DAG: [[CLAMP:%.+]] = mhlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor @@ -557,11 +557,11 @@ func.func @clip_dynamic_broadcast(%arg0 : tensor, %arg1 : tensor, %a // CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> func.func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { // CHECK: %[[RS:.*]] = mhlo.reshape %[[ARG]] : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> - // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32> - // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32> + // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<12x12xi32> + // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<12x12xi32> // CHECK-DAG: %[[COMP:.*]] = mhlo.compare EQ, %[[IOTA0]], %[[IOTA1]], 
NOTYPE : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor) -> tensor<12x12xf32> + // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) <{broadcast_sizes = dense<12> : tensor<2xi64>}> : (tensor) -> tensor<12x12xf32> // CHECK-DAG: %[[SEL:.*]] = mhlo.select %[[COMP]], %[[RS]], %[[ZERO_MAT]] : tensor<12x12xi1>, tensor<12x12xf32> // CHECK-DAG: %[[RED:.*]] = mhlo.reduce(%[[SEL]] init: %[[ZERO]]) applies mhlo.add across dimensions = [0] : (tensor<12x12xf32>, tensor) -> tensor<12xf32> // CHECK-DAG: %[[RES:.*]] = mhlo.reshape %[[RED]] : (tensor<12xf32>) -> tensor<4x3xf32> @@ -581,22 +581,22 @@ func.func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { func.func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { // CHECK-DAG: %[[V0:.*]] = mhlo.constant dense<42> : tensor // CHECK-DAG: %[[V1:.*]] = mhlo.constant dense<[-10, 11]> : tensor<2xi32> - // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x22x128xi32> - // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() <{iota_dimension = 2 : i64}> : () -> tensor<1x22x128xi32> // CHECK-DAG: %[[V4:.*]] = mhlo.constant dense<0> : tensor - // CHECK-DAG: %[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) <{broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<1x22x128xi32> // CHECK-DAG: %[[V6:.*]] = mhlo.constant dense : tensor - // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) {broadcast_sizes = dense<[1, 22, 128]> : 
tensor<3xi64>} : (tensor) -> tensor<1x22x128xi1> + // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) <{broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<1x22x128xi1> // CHECK-DAG: %[[V8:.*]] = mhlo.constant dense : tensor - // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi1> + // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) <{broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<1x22x128xi1> // CHECK-DAG: %[[V10:.*]] = mhlo.constant dense<11> : tensor - // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) <{broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<1x22x128xi32> // CHECK-DAG: %[[V12:.*]] = mhlo.constant dense<140> : tensor - // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) <{broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<1x22x128xi32> // CHECK-DAG: %[[V14:.*]] = mhlo.constant dense<128> : tensor - // CHECK-DAG: %[[V15:.*]] = "mhlo.broadcast"(%[[V14]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V15:.*]] = "mhlo.broadcast"(%[[V14]]) <{broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<1x22x128xi32> // CHECK-DAG: %[[V16:.*]] = mhlo.constant dense<128> : tensor - // CHECK-DAG: %[[V17:.*]] = "mhlo.broadcast"(%[[V16]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V17:.*]] = "mhlo.broadcast"(%[[V16]]) <{broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<1x22x128xi32> // CHECK-DAG: 
%[[V18:.*]] = mhlo.subtract %[[V11]], %[[V2]] : tensor<1x22x128xi32> // CHECK-DAG: %[[V19:.*]] = mhlo.negate %[[V18]] : tensor<1x22x128xi32> // CHECK-DAG: %[[V20:.*]] = mhlo.minimum %[[V18]], %[[V5]] : tensor<1x22x128xi32> @@ -621,10 +621,10 @@ func.func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32 // CHECK-DAG: %[[V39:.*]] = mhlo.and %[[V37]], %[[V38]] : tensor<1x22x128xi1> // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1> // CHECK-DAG: %[[V41:.*]] = mhlo.reshape %[[V40]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> - // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) {dimension = 0 : i64} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> - // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>} : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> - // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) {broadcast_sizes = dense<7> : tensor<1xi64>} : (tensor<22x128xi1>) -> tensor<7x22x128xi1> - // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) {broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<7x22x128xi32> + // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) <{dimension = 0 : i64}> : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> + // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) <{dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> + // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) <{broadcast_sizes = dense<7> : tensor<1xi64>}> : (tensor<22x128xi1>) -> tensor<7x22x128xi1> + // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) <{broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<7x22x128xi32> // 
CHECK: %[[V46:.*]] = mhlo.select %[[V44]], %[[V43]], %[[V45]] : tensor<7x22x128xi1>, tensor<7x22x128xi32> // CHECK: return %[[V46]] : tensor<7x22x128xi32> %0 = mhlo.constant dense<42> : tensor // padding value @@ -670,7 +670,7 @@ func.func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x2 T = i32, align = "LEFT_LEFT" } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> // CHECK: %[[false:.*]] = mhlo.constant dense : tensor - // CHECK: %[[b_false:.*]] = "mhlo.broadcast"(%[[false]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi1> + // CHECK: %[[b_false:.*]] = "mhlo.broadcast"(%[[false]]) <{broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<1x22x128xi1> // CHECK: %{{[0-9]*}} = mhlo.select %[[b_false]], %{{[0-9]*}}, %{{[0-9]*}} : tensor<1x22x128xi1>, tensor<1x22x128xi32> func.return %2: tensor<7x22x128xi32> } @@ -713,7 +713,7 @@ func.func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x2 T = i32, align = "RIGHT_RIGHT" } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> // CHECK: %[[true:.*]] = mhlo.constant dense : tensor - // CHECK: %[[b_true:.*]] = "mhlo.broadcast"(%[[true]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi1> + // CHECK: %[[b_true:.*]] = "mhlo.broadcast"(%[[true]]) <{broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<1x22x128xi1> // CHECK: %{{[0-9]*}} = mhlo.select %[[b_true]], %{{[0-9]*}}, %{{[0-9]*}} : tensor<1x22x128xi1>, tensor<1x22x128xi32> func.return %2: tensor<7x22x128xi32> } @@ -1009,7 +1009,7 @@ func.func @floormod_dynamic_broadcast_denominator_(%arg0: tensor, %arg1 // CHECK-LABEL: @ones_like // CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) func.func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { - // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 1.0{{.*}}} + // CHECK: %[[RES:.*]] = 
"chlo.constant_like"(%[[ARG]]) <{value = 1.0{{.*}}}> // CHECK: return %[[RES]] %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> func.return %0 : tensor<2x?xf32> @@ -1024,7 +1024,7 @@ func.func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { // CHECK-LABEL: @zeros_like // CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) func.func @zeros_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { - // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 0.0{{.*}}} + // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) <{value = 0.0{{.*}}}> // CHECK: return %[[RES]] %0 = "tf.ZerosLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> func.return %0 : tensor<2x?xf32> @@ -1086,7 +1086,7 @@ func.func @real(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { // CHECK-LABEL: func @concat_v2 func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + // CHECK: "mhlo.concatenate"({{.*}}) <{dimension = 0 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> func.return %1 : tensor<6x3xf32> @@ -1096,7 +1096,7 @@ func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6 // CHECK-LABEL: func @concat_v2_neg_axis func.func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + // CHECK: "mhlo.concatenate"({{.*}}) <{dimension = 0 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> %axis = "tf.Const"() { value = dense<-2> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> @@ -1107,7 +1107,7 
@@ func.func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> // CHECK-LABEL: func @concat_v2_1d_axis func.func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { - // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> + // CHECK: "mhlo.concatenate"({{.*}}) <{dimension = 1 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64> %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32> @@ -1132,7 +1132,7 @@ func.func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf3 // CHECK-LABEL: func @padv2_1D func.func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor<6xf32> { %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64> - // CHECK: "mhlo.pad"(%arg0, %arg1) { + // CHECK: "mhlo.pad"(%arg0, %arg1) <{ // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>, // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64> @@ -1145,7 +1145,7 @@ func.func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor<6xf32> { // CHECK-LABEL: func @padv2_2D func.func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64> - // CHECK: "mhlo.pad"(%arg0, %arg1) { + // CHECK: "mhlo.pad"(%arg0, %arg1) <{ // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> @@ -1158,7 +1158,7 @@ func.func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf3 // CHECK-LABEL: func @padv2_i32_paddings func.func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor) 
-> tensor<6x9xf32> { %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32> - // CHECK: "mhlo.pad"(%arg0, %arg1) { + // CHECK: "mhlo.pad"(%arg0, %arg1) <{ // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> @@ -1170,10 +1170,10 @@ func.func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor) -> ten // CHECK-LABEL: func @padv2_dynamic func.func @padv2_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor<1x2xi64>) -> tensor { - // CHECK: "mhlo.transpose"({{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x2xi64>) -> tensor<2x1xi64> + // CHECK: "mhlo.transpose"({{.*}}) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<1x2xi64>) -> tensor<2x1xi64> // CHECK: mhlo.reshape {{.*}} : (tensor<2x1xi64>) -> tensor<2xi64> - // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: "mhlo.slice"({{.*}}) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: "mhlo.slice"({{.*}}) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi64>) -> tensor<1xi64> // CHECK: mhlo.dynamic_pad {{.*}} : (tensor, tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor, tensor<1x2xi64>, tensor) -> tensor func.return %1 : tensor @@ -1237,7 +1237,7 @@ func.func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> { // 
CHECK-LABEL: func @infeed_dequeue_tuple func.func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) { // CHECK: [[TOKEN:%.*]] = mhlo.create_token : !mhlo.token -// CHECK: [[INFEED:%.*]]:3 = "mhlo.infeed"([[TOKEN]]) {infeed_config = ""{{.*}}} : (!mhlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !mhlo.token) +// CHECK: [[INFEED:%.*]]:3 = "mhlo.infeed"([[TOKEN]]) <{infeed_config = ""{{.*}}}> : (!mhlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !mhlo.token) // CHECK: return [[INFEED]]#0, [[INFEED]]#1 %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) func.return %0#0, %0#1 : tensor<1x8x4x4xi32>, tensor<1x100x1xf32> @@ -1338,7 +1338,7 @@ func.func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tens // CHECK-LABEL: matmul_transpose_b // CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>) func.func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { - // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} + // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) <{permutation = dense<[1, 0]> : tensor<2xi64>}> // CHECK: "mhlo.dot"(%[[A]], %[[UPDATED_B]]) %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> @@ -1350,8 +1350,8 @@ func.func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tens // CHECK-LABEL: matmul_transpose_both // CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>) func.func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { - // CHECK: %[[UPDATED_A:.*]] = "mhlo.transpose"(%[[A]]) {permutation = dense<[1, 0]> : tensor<2xi64>} - // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} + // CHECK: %[[UPDATED_A:.*]] = "mhlo.transpose"(%[[A]]) <{permutation = dense<[1, 0]> : tensor<2xi64>}> + 
// CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) <{permutation = dense<[1, 0]> : tensor<2xi64>}> // CHECK: "mhlo.dot"(%[[UPDATED_A]], %[[UPDATED_B]]) %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> @@ -1419,9 +1419,9 @@ func.func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4 func.func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK: <{window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}> // CHECK: mhlo.maximum // CHECK: mhlo.return - // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> func.return %0 : tensor<2x3x5x7xi32> @@ -1445,9 +1445,9 @@ func.func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7x func.func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> { // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK: <{window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>}> // CHECK: mhlo.maximum // CHECK: mhlo.return - // CHECK: {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>} %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> func.return %0 : tensor<2x8x3x5x7xf32> @@ -1485,7 +1485,7 @@ func.func @maxpool_explicit_padding(%arg0: 
tensor<2x12x20x7xi32>) -> tensor<2x3x // CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32> func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ({ + // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) <{window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({ // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor @@ -1493,7 +1493,7 @@ func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_outpu // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor - // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + // CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32> %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { data_format = "NHWC", @@ -1510,7 +1510,7 @@ func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_outpu // CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32> func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, 
%orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ({ + // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) <{window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>}> ({ // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor @@ -1518,7 +1518,7 @@ func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_ // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor - // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor) -> tensor<10x8x24x24x64xf32> + // CHECK: }) : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor) -> tensor<10x8x24x24x64xf32> // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32> %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> func.return %result : tensor<10x8x24x24x64xf32> @@ -1555,11 +1555,11 @@ func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_out // CHECK-LABEL:one_hot func.func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { - // CHECK: %[[IOTA:.*]] = "mhlo.iota"() 
{iota_dimension = 1 : i64} : () -> tensor<3x5xi32> - // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32> + // CHECK: %[[IOTA:.*]] = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<3x5xi32> + // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<3xi32>) -> tensor<3x5xi32> // CHECK: %[[COMPARE:.*]] = mhlo.compare EQ, %[[BCAST_ARG0]], %[[IOTA]], NOTYPE : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> - // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> - // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> + // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) <{broadcast_sizes = dense<[3, 5]> : tensor<2xi64>}> : (tensor) -> tensor<3x5xf32> + // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) <{broadcast_sizes = dense<[3, 5]> : tensor<2xi64>}> : (tensor) -> tensor<3x5xf32> // CHECK: %[[RESULT:.*]] = mhlo.select %[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]] : tensor<3x5xi1>, tensor<3x5xf32> // CHECK: return %[[RESULT]] : tensor<3x5xf32> %depth = "tf.Const"() { value = dense<5> : tensor } : () -> tensor @@ -1577,7 +1577,7 @@ func.func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: // CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>) func.func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () { // CHECK: [[TOKEN:%.*]] = mhlo.create_token : !mhlo.token -// CHECK: "mhlo.outfeed"([[VAL_0]], [[VAL_1]], [[TOKEN]]) {outfeed_config = ""} : (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token +// CHECK: "mhlo.outfeed"([[VAL_0]], [[VAL_1]], [[TOKEN]]) <{outfeed_config = ""}> : (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token "tf.OutfeedEnqueueTuple"(%data_1, 
%data_2) : (tensor<3xi32>, tensor<4xf32>) -> () func.return } @@ -1592,7 +1592,7 @@ func.func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) func.func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // CHECK: mhlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> // CHECK: mhlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> - // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> + // CHECK: "mhlo.concatenate"({{.*}}) <{dimension = 0 : i64}> : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> func.return %0 : tensor<2x2xi32> @@ -1689,7 +1689,7 @@ func.func @callee() { func.func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) <{dimensions = dense<0> : tensor<1xi64>}> %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> // CHECK: return [[VAL]] : tensor<5xi32> @@ -1702,7 +1702,7 @@ func.func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { func.func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) - // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) <{dimensions = dense<0> : tensor<1xi64>}> %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> // CHECK: return [[VAL]] : tensor<5xi32> @@ -1715,7 +1715,7 @@ func.func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { func.func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { %axis = "tf.Const"() {value = dense<[-1]> : 
tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>} + // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) <{dimensions = dense<1> : tensor<1xi64>}> %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> // CHECK: return [[VAL]] : tensor<5x5xi32> @@ -1762,7 +1762,7 @@ func.func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) - // CHECK-LABEL: func @elu func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { - // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1xf32>) -> tensor<1xf32> + // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1xf32>) -> tensor<1xf32> // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg0, %[[ZERO]] // CHECK-DAG: %[[EXP:.*]] = mhlo.exponential_minus_one %arg0 // CHECK: %[[RESULT:.*]] = mhlo.select %[[PRED]], %arg0, %[[EXP]] @@ -1837,8 +1837,8 @@ func.func @relu6_unsigned(%arg0: tensor) -> tensor { // CHECK-LABEL: func @leaky_relu func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} { - // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> - // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> + // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> + // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<1x4x4x3xf32> // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]], NOTYPE : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1> // CHECK-NEXT: %[[RES:.*]] = 
mhlo.select %[[CMP]], %[[INP]], %[[LEAKY]] : tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32> @@ -1851,8 +1851,8 @@ func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attribu // CHECK-LABEL: func @leaky_relu_grad func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} { - // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) {value = 2.000000e-01 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> - // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> + // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> + // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<1x4x4xf32> // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1> // CHECK-NEXT: %[[RES:.*]] = mhlo.select %[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]] : tensor<1x4x4xi1>, tensor<1x4x4xf32> @@ -1865,7 +1865,7 @@ func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) - // CHECK-LABEL: func @softsign func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> { - // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) <{value = 1.000000e+00 : f32}> : (tensor<4x10xf32>) -> tensor<4x10xf32> // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32> // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<4x10xf32> // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<4x10xf32> @@ -1903,7 +1903,7 @@ func.func @Roll_0D(%arg0: tensor<512xi32>, 
%shift: tensor) -> tensor<512xi3 // CHECK: %[[T1:.+]] = mhlo.remainder %arg1, %[[AXIS_SIZE]] : tensor // CHECK: %[[T2:.+]] = mhlo.add %[[T1]], %[[AXIS_SIZE]] : tensor // CHECK: %[[T3:.+]] = mhlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor - // CHECK: %[[CONCAT:.+]] = "mhlo.concatenate"(%arg0, %arg0) {dimension = 0 : i64} + // CHECK: %[[CONCAT:.+]] = "mhlo.concatenate"(%arg0, %arg0) <{dimension = 0 : i64}> // CHECK: %[[OFFSET:.+]] = mhlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor // CHECK: "mhlo.dynamic_slice"(%[[CONCAT]], %[[OFFSET]]) // CHECK-SAME: {slice_sizes = dense<512> : tensor<1xi64>} @@ -1920,7 +1920,7 @@ func.func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor) -> tensor<512xi3 // CHECK-LABEL: func @select_batch_static func.func @select_batch_static(%arg0: tensor<2xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { - // CHECK: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %{{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<3xindex>) -> tensor<2x6x8xi1> + // CHECK: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %{{.*}}) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<2xi1>, tensor<3xindex>) -> tensor<2x6x8xi1> // CHECK: mhlo.select %[[BCAST]], %arg1, %arg2 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> func.return %0: tensor<2x6x8xi32> @@ -1958,7 +1958,7 @@ func.func @select_batch_dynamic_r1(%arg0: tensor, %arg1: tensor // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor) { // CHECK-NEXT: %[[SHAPE1E:.*]] = shape.to_extent_tensor %[[SHAPE1]] : tensor<3xindex> -> tensor<3xindex> - // CHECK-NEXT: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1E]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<3xindex>) -> tensor + // CHECK-NEXT: %[[BCAST:.*]] = 
"mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1E]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor, tensor<3xindex>) -> tensor // CHECK-NEXT: %[[SELECT:.*]] = mhlo.select %[[BCAST]], %arg1, %arg2 : tensor, tensor // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor @@ -2026,7 +2026,7 @@ func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32> // CHECK-LABEL: func @fft_1D func.func @fft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { - // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo} : (tensor<8xcomplex> + // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo}> : (tensor<8xcomplex> %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> func.return %0 : tensor<8xcomplex> } @@ -2035,7 +2035,7 @@ func.func @fft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { // CHECK-LABEL: func @ifft_1D func.func @ifft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { - // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo} : (tensor<8xcomplex> + // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo}> : (tensor<8xcomplex> %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> func.return %0 : tensor<8xcomplex> } @@ -2045,7 +2045,7 @@ func.func @ifft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { // CHECK-LABEL: func @rfft_1D func.func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex> { %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo} : (tensor<8xf32> + // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo}> : (tensor<8xf32> %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<5xcomplex> func.return %0 : tensor<5xcomplex> } 
@@ -2055,8 +2055,8 @@ func.func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex> { // CHECK-LABEL: func @rfft_1D_padded func.func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex> { %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %{{.*}}) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor) -> tensor<8xf32> - // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo} : (tensor<8xf32> + // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %{{.*}}) <{edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<7xf32>, tensor) -> tensor<8xf32> + // CHECK: "mhlo.fft"(%[[PADDED]]) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo}> : (tensor<8xf32> %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<5xcomplex> func.return %0 : tensor<5xcomplex> } @@ -2066,8 +2066,8 @@ func.func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex> { // CHECK-LABEL: func @rfft_1D_sliced func.func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex> { %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32> - // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo} : (tensor<2x8xf32> + // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) <{limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x9xf32>) -> tensor<2x8xf32> + // CHECK: "mhlo.fft"(%[[SLICED]]) <{fft_length = dense<8> 
: tensor<1xi64>, fft_type = #mhlo}> : (tensor<2x8xf32> %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x5xcomplex> func.return %0 : tensor<2x5xcomplex> } @@ -2077,8 +2077,8 @@ func.func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex> { // CHECK-LABEL: func @irfft_1D func.func @irfft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xf32> { %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex>) -> tensor<5xcomplex> - // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo} : (tensor<5xcomplex> + // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) <{limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<8xcomplex>) -> tensor<5xcomplex> + // CHECK: "mhlo.fft"(%[[SLICED]]) <{fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo}> : (tensor<5xcomplex> %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex>, tensor<1xi32>) -> tensor<8xf32> func.return %0 : tensor<8xf32> } @@ -2816,7 +2816,7 @@ func.func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK-LABEL: slice_i32_consts func.func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> + // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) %0 = "tf.Slice"(%arg0, %starts, %sizes) : 
(tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> @@ -2828,7 +2828,7 @@ func.func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK-LABEL: slice_constant_start_negative_one_size func.func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<3xi32> + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<3> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<3xi32> // CHECK: return %[[RESULT]] : tensor<3xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) @@ -2867,7 +2867,7 @@ func.func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> @@ -2961,7 +2961,7 @@ func.func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2 %end = "tf.Const"() 
{value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - // CHECK: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> // CHECK: mhlo.slice // CHECK-DAG-SAME: start_indices = dense<[0, 1]> @@ -3581,7 +3581,7 @@ func.func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { // CHECK-LABEL: func @tile_by_reshape func.func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { - // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> + // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[BROADCASTED]] : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32> // CHECK: return %[[RESULT]] : tensor<28x24xf32> %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> @@ -3593,7 +3593,7 @@ func.func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { // CHECK-LABEL: func @tile_just_broadcast func.func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> { - // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<7x3xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<7x3xf32> // CHECK: return %[[RESULT]] : tensor<7x3xf32> %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32> @@ -3607,7 +3607,7 @@ func.func @tile_dynamic_shape(%arg0: tensor) -> tensor { %multiples = 
"tf.Const"() { value = dense<[7,3]> : tensor<2xi32> } : () -> tensor<2xi32> // CHECK: tensor.dim {{.*}} : tensor // CHECK: tensor.from_elements {{.*}} : tensor<4xindex> - // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor, tensor<4xindex>) -> tensor // CHECK: muli {{.*}} : index // CHECK: tensor.from_elements {{.*}} : tensor<2xindex> // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xindex>) -> tensor @@ -3626,7 +3626,7 @@ func.func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor< // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<3x7xi32> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x7xi32> // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) // CHECK: (%[[ARG1:.*]]: tensor, %[[ARG3:.*]]: tensor) (%[[ARG2:.*]]: tensor, %[[ARG4:.*]]: tensor) // CHECK: %[[COMPARE:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor @@ -3649,7 +3649,7 @@ func.func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor< // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xf32> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<3x7xi64> + // CHECK: %[[INDEX:.*]] = 
"mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x7xi64> // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor @@ -3664,7 +3664,7 @@ func.func @argmax_i1_input_i64_output_axis_1(%arg0: tensor<3x7xi1>) -> tensor<3x // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense : tensor // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi1> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<3x7xi64> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x7xi64> // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor @@ -3679,7 +3679,7 @@ func.func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor : tensor // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<3x?xi32> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x?xi32> // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) // CHECK: return %[[REDUCE]]#1 : tensor %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor @@ -3694,7 +3694,7 @@ func.func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> { // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor // 
CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<3x?xi32> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x?xi32> // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) // CHECK: return %[[REDUCE]]#1 : tensor<3xi32> %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor @@ -3709,7 +3709,7 @@ func.func @argmin_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor< // CHECK: %[[INIT:.*]] = mhlo.constant dense<9223372036854775807> : tensor // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<3x7xi32> + // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x7xi32> // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) // CHECK: (%[[ARG1:.*]]: tensor, %[[ARG3:.*]]: tensor) (%[[ARG2:.*]]: tensor, %[[ARG4:.*]]: tensor) // CHECK: %[[COMPARE:.*]] = mhlo.compare LE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor @@ -3763,7 +3763,7 @@ func.func @random_uniform_with_seeds(%arg0: tensor<4xi32>) -> tensor<32x12x12x64 // CHECK-NEXT: %1 = mhlo.constant dense<0.000000e+00> : tensor // CHECK-NEXT: %2 = mhlo.constant dense<1.000000e+00> : tensor // CHECK-NEXT: %3 = mhlo.convert %0 : (tensor<4xi32>) -> tensor<4xi64> - // CHECK-NEXT: %4 = "mhlo.rng"(%1, %2, %3) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> + // CHECK-NEXT: %4 = "mhlo.rng"(%1, %2, %3) 
<{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> %cst = "tf.Const"() {value = dense<[32, 12, 12, 64]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = "tf.RandomUniform"(%cst) {seed = 87654321 : i64, seed2 = 0 : i64} : (tensor<4xi32>) -> tensor<32x12x12x64xf32> // CHECK: return %4 : tensor<32x12x12x64xf32> @@ -3813,7 +3813,7 @@ func.func @range_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} @@ -3837,7 +3837,7 @@ func.func @range_int_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tens // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]] // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert [[CEIL]] // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT_3]] - // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} + // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) <{iota_dimension = 0 : i64}> // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} @@ -3859,7 +3859,7 @@ func.func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]] // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] - // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} + // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = array} // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] 
{broadcast_dimensions = array} // CHECK: return [[LINSPACE]] @@ -3998,7 +3998,7 @@ func.func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor< // CHECK-LABEL: @conv2d_backprop_input_dynamic func.func @conv2d_backprop_input_dynamic(%filter: tensor<2x2x1x16xf32>, %out_backprop: tensor) -> tensor { - // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} @@ -4023,7 +4023,7 @@ func.func @conv2d_backprop_input( %filter: tensor<3x3x1x32xf32>, %out_backprop: tensor<100x26x26x32xf32> ) -> tensor<100x28x28x1xf32> { - // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} @@ -4073,7 +4073,7 @@ func.func @conv2d_backprop_input_grouped( // CHECK-LABEL: @conv3d_backprop_input func.func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> { - // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, o, i]->[b, 0, 1, 2, f] // CHECK-SAME{LITERAL}: 
window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} @@ -4184,7 +4184,7 @@ func.func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: conv_dynamic func.func @conv_dynamic(%arg0: tensor, %arg1: tensor<3x3x3x16xf32>) -> tensor { // CHECK: "mhlo.dynamic_conv" - // CHECK-SAME: {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>} : (tensor, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor + // CHECK-SAME: <{batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>}> : (tensor, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor, tensor<3x3x3x16xf32>) -> tensor func.return %0 : tensor } @@ -4245,8 +4245,8 @@ func.func @split_not_match_static_split_dim_size(%input: tensor<4x?x4xf32>) -> ( // CHECK-LABEL: @split_match_and_split_into_two func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) { %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32> + // CHECK: %[[ONE:.*]] = 
"mhlo.slice"(%{{.*}}) <{limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<2x6xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<2x6xf32> %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) // CHECK: return %[[ONE]], %[[TWO]] func.return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32> @@ -4258,9 +4258,9 @@ func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x // CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) func.func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) { %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : 
(tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> %0:3 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] func.return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32> @@ -4312,9 +4312,9 @@ func.func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8x func.func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32> + // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x1xf32> + // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: 
%[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x3xf32> %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> @@ -4365,11 +4365,11 @@ func.func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { // CHECK-LABEL: @unpack func.func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { - // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> // CHECK: %[[RES1:.*]] = mhlo.reshape %[[SLICE1]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> // CHECK: %[[RES2:.*]] = mhlo.reshape %[[SLICE2]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : 
(tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> // CHECK: %[[RES3:.*]] = mhlo.reshape %[[SLICE3]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) @@ -4405,19 +4405,19 @@ func.func @unpack_dynamic(%arg0: tensor) -> (tensor, tensor< func.func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor) -> tensor<4x64xf32> - // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ({ - // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor - // CHECK: mhlo.return [[ADD]] - // CHECK: indices_are_sorted = false, + // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) <{broadcast_sizes = dense<[4, 64]> : tensor<2xi64>}> : (tensor) -> tensor<4x64xf32> + // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) + // CHECK-SAME: indices_are_sorted = false, // CHECK-SAME: scatter_dimension_numbers = // CHECK-SAME: update_window_dims = [2] // CHECK-SAME: inserted_window_dims = [0] // CHECK-SAME: scatter_dims_to_operand_dims = [0] // CHECK-SAME: index_vector_dim = 2 // CHECK-SAME: unique_indices = false - // CHECK-SAME: (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32> + // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): + // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor + // CHECK: mhlo.return [[ADD]] + // 
CHECK-NEXT: (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32> // CHECK: return [[SCATTER]] %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor) -> (tensor<4x64xf32>) func.return %0: tensor<4x64xf32> @@ -4431,19 +4431,19 @@ func.func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tenso func.func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor) -> tensor<4x64xf32> - // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ({ - // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor - // CHECK: mhlo.return [[MUL]] - // CHECK: indices_are_sorted = false + // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) <{broadcast_sizes = dense<[4, 64]> : tensor<2xi64>}> : (tensor) -> tensor<4x64xf32> + // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) + // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers = // CHECK-SAME: update_window_dims = [2] // CHECK-SAME: inserted_window_dims = [0] // CHECK-SAME: scatter_dims_to_operand_dims = [0] // CHECK-SAME: index_vector_dim = 2 // CHECK-SAME: unique_indices = false - // CHECK-SAME: (tensor<4x64xf32>, tensor, tensor<8x?x64xf32>) -> tensor<4x?xf32> + // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): + // CHECK: [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor + // CHECK: mhlo.return [[MUL]] + // CHECK-NEXT: (tensor<4x64xf32>, tensor, tensor<8x?x64xf32>) -> tensor<4x?xf32> // CHECK: return [[SCATTER]] %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, 
tensor, tensor) -> (tensor<4x?xf32>) func.return %0: tensor<4x?xf32> @@ -4496,7 +4496,7 @@ func.func @gatherNd_dynamic(%arg0: tensor, %arg1: tensor) // CHECK-LABEL: func @gatherNd_static func.func @gatherNd_static(%arg0: tensor<2x4x128xf32>, %arg1: tensor<2x1xi32>) -> tensor<2x4x128xf32> { - // CHECK: "mhlo.gather"({{.*}}) { + // CHECK: "mhlo.gather"({{.*}}) <{ // CHECK-SAME: dimension_numbers = // CHECK-SAME: offset_dims = [1, 2] // CHECK-SAME: collapsed_slice_dims = [0] @@ -4610,9 +4610,9 @@ func.func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024 %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) // CHECK: [[RESHAPE:%.*]] = mhlo.reshape %arg0 : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> - // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> + // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) <{dimensions = dense<2> : tensor<1xi64>}> : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) {edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>} : (tensor<4x16x1022xf32>, tensor) -> tensor<4x128x1024xf32> + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) <{edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>}> : (tensor<4x16x1022xf32>, tensor) -> tensor<4x128x1024xf32> %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> // CHECK: return [[PAD]] @@ -4744,10 +4744,7 @@ func.func 
@strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> te // CHECK-LABEL: @tensor_scatter_update func.func @tensor_scatter_update(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: mhlo.return %arg4 : tensor - // CHECK: }) + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers // CHECK-SAME: update_window_dims = [1] @@ -4755,6 +4752,9 @@ func.func @tensor_scatter_update(%tensor: tensor, %indices: tensor, %arg4: tensor): + // CHECK: mhlo.return %arg4 : tensor + // CHECK: }) %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -4764,7 +4764,7 @@ func.func @tensor_scatter_update(%tensor: tensor, %indices: tensor, %indices: tensor<2x1xi32>, %updates: tensor) -> tensor<4x3xi32> { // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> - // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xi64>) -> tensor<2x3xi32> + // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi64>) -> tensor<2x3xi32> // CHECK: "mhlo.scatter" %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor) -> tensor<4x3xi32> func.return %0 : tensor<4x3xi32> @@ -4774,11 +4774,7 @@ func.func @tensor_scatter_update_scalar_update(%tensor: tensor<4x3xi32>, %indice // CHECK-LABEL: @tensor_scatter_add func.func @tensor_scatter_add(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.add %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) // 
CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers // CHECK-SAME: update_window_dims = [1] @@ -4786,6 +4782,10 @@ func.func @tensor_scatter_add(%tensor: tensor, %indices: tensor, %arg4: tensor): + // CHECK: %1 = mhlo.add %arg3, %arg4 : tensor + // CHECK: mhlo.return %1 : tensor + // CHECK: }) %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -4795,8 +4795,8 @@ func.func @tensor_scatter_add(%tensor: tensor, %indices: tensor, %indices: tensor<2x1xi32>, %updates: tensor) -> tensor<4x3xi32> { // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> - // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xi64>) -> tensor<2x3xi32> - // CHECK: "mhlo.scatter" + // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi64>) -> tensor<2x3xi32> + // CHECK: "mhlo.scatter %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor) -> tensor<4x3xi32> func.return %0 : tensor<4x3xi32> } @@ -4805,11 +4805,7 @@ func.func @tensor_scatter_add_scalar_update(%tensor: tensor<4x3xi32>, %indices: // CHECK-LABEL: @tensor_scatter_sub func.func @tensor_scatter_sub(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.subtract %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers // CHECK-SAME: update_window_dims = [1] @@ -4817,6 +4813,10 @@ func.func @tensor_scatter_sub(%tensor: tensor, %indices: tensor, %arg4: tensor): + // CHECK: %1 = mhlo.subtract %arg3, %arg4 : tensor + // CHECK: mhlo.return %1 : tensor + // CHECK: }) %0 = 
"tf.TensorScatterSub"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -4825,11 +4825,7 @@ func.func @tensor_scatter_sub(%tensor: tensor, %indices: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.minimum %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers // CHECK-SAME: update_window_dims = [1] @@ -4837,6 +4833,10 @@ func.func @tensor_scatter_min(%tensor: tensor, %indices: tensor, %arg4: tensor): + // CHECK: %1 = mhlo.minimum %arg3, %arg4 : tensor + // CHECK: mhlo.return %1 : tensor + // CHECK: }) %0 = "tf.TensorScatterMin"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -4845,11 +4845,7 @@ func.func @tensor_scatter_min(%tensor: tensor, %indices: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.maximum %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers // CHECK-SAME: update_window_dims = [1] @@ -4857,6 +4853,10 @@ func.func @tensor_scatter_max(%tensor: tensor, %indices: tensor, %arg4: tensor): + // CHECK: %1 = mhlo.maximum %arg3, %arg4 : tensor + // CHECK: mhlo.return %1 : tensor + // CHECK: }) %0 = "tf.TensorScatterMax"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor } @@ -4894,11 +4894,11 @@ func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { // CHECK-DAG: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64> // CHECK-DAG: [[LOWER:%.*]] = mhlo.constant 
dense<0> : tensor // CHECK-DAG: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor - // CHECK: [[RNG:%.*]] = "mhlo.rng"([[LOWER]], [[UPPER]], [[SHAPE]]) {rng_distribution = #mhlo.rng_distribution} - // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ({ + // CHECK: [[RNG:%.*]] = "mhlo.rng"([[LOWER]], [[UPPER]], [[SHAPE]]) <{rng_distribution = #mhlo.rng_distribution}> + // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) <{dimension = -1 : i64, is_stable = {{.*}}}> ({ // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor, {{.*}}: tensor, {{.*}}: tensor): // CHECK: mhlo.compare LT, [[ARG1]], [[ARG2]], TOTALORDER - // CHECK: }) {dimension = -1 : i64, is_stable = {{.*}}} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>) + // CHECK: }) : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>) // CHECK: return [[SORT]]#1 %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>) func.return %0: tensor<16xf32> @@ -4921,12 +4921,12 @@ func.func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf3 // CHECK-LABEL: @random_shuffle_3D // CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32> func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { - // CHECK: [[INDICES:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + // CHECK: [[INDICES:%.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> // CHECK-DAG: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64> // CHECK-DAG: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor // CHECK-DAG: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor - // CHECK: [[SWAPS:%.*]] = "mhlo.rng"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) {rng_distribution = #mhlo.rng_distribution} + // CHECK: [[SWAPS:%.*]] = "mhlo.rng"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) <{rng_distribution = #mhlo.rng_distribution}> // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor @@ -4935,10 +4935,10 @@ func.func 
@random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[CMP:%.*]] = mhlo.compare LT, [[ITER_ARG0]], [[LIMIT]], NOTYPE // CHECK: mhlo.return [[CMP]] // CHECK: } do { - // CHECK: [[SRC_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[ITER_ARG0]]) {slice_sizes = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<1xi32> - // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG1]], [[ITER_ARG0]]) {slice_sizes = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<1xi32> + // CHECK: [[SRC_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[ITER_ARG0]]) <{slice_sizes = dense<1> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<1xi32> + // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG1]], [[ITER_ARG0]]) <{slice_sizes = dense<1> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<1xi32> // CHECK: [[SWP:%.*]] = mhlo.reshape [[SWP_IDX]] : (tensor<1xi32>) -> tensor - // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[SWP]]) {slice_sizes = dense<1> : tensor<1xi64>} + // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[SWP]]) <{slice_sizes = dense<1> : tensor<1xi64>}> // CHECK: [[INDICES1:%.*]] = mhlo.dynamic_update_slice [[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG0]] : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[INDICES2:%.*]] = mhlo.dynamic_update_slice [[INDICES1]], [[SRC_IDX]], [[SWP]] : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor @@ -4952,7 +4952,7 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[INDEX_CAST:%.*]] = arith.index_cast [[SHAPE_DIM]] : index to i64 // CHECK: [[FROM_ELEMENTS:%.*]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64> // CHECK: [[CONSTANT2:%.*]] = mhlo.constant dense<16> : tensor<1xi64> - // CHECK: [[CONCATENATE:%.*]] = "mhlo.concatenate"([[CONSTANT1]], [[FROM_ELEMENTS]], [[CONSTANT2]]) {dimension = 0 : i64} : 
(tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> + // CHECK: [[CONCATENATE:%.*]] = "mhlo.concatenate"([[CONSTANT1]], [[FROM_ELEMENTS]], [[CONSTANT2]]) <{dimension = 0 : i64}> : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> // CHECK: [[DYNAMIC_GATHER:%.*]] = "mhlo.dynamic_gather"([[INPUT]], [[WHILE_OUT]]#2, [[CONCATENATE]]) // CHECK-SAME: dimension_numbers = // CHECK-SAME: offset_dims = [1, 2] @@ -4978,13 +4978,13 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16> // CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32> // CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({ +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> // CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): // CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] // CHECK: mhlo.return [[ADD]] // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> // CHECK-SAME: -> tensor<2x3x5x7xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] @@ -5004,13 +5004,13 @@ func.func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7 // CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16> // CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32> // CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({ +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> +// 
CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> // CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): // CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] // CHECK: mhlo.return [[ADD]] // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> // CHECK-SAME: -> tensor<2x4x3x5x7xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] @@ -5030,13 +5030,13 @@ func.func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x // CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16> // CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32> // CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({ +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]> +// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> // CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): // CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] // CHECK: mhlo.return [[ADD]] // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> // CHECK-SAME: -> tensor<2x7x3x5xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] @@ -5056,13 +5056,13 @@ func.func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf // CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16> // CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32> // CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({ +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], 
[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]> +// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> // CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): // CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] // CHECK: mhlo.return [[ADD]] // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]> -// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> // CHECK-SAME: -> tensor<2x7x4x3x5xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] @@ -5081,24 +5081,24 @@ func.func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7 // CHECK-LABEL: @avgpool_same_padding( // CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ({ +// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: mhlo.return %[[SUM1]] : tensor // CHECK: }) +// CHECK-SAME: -> tensor<2x4x6x7xf32> +// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> // CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> // CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> -// CHECK-SAME: -> tensor<2x4x6x7xf32> -// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ({ // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: 
tensor): // CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor // CHECK: mhlo.return %[[SUM2]] : tensor // CHECK: }) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> // CHECK-SAME: -> tensor<2x4x6x7xf32> // CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32> // CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32> @@ -5113,24 +5113,24 @@ func.func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7x // CHECK-LABEL: @avgpool_3d_same_padding( // CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ({ +// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: mhlo.return %[[SUM1]] : tensor // CHECK: }) +// CHECK-SAME: -> tensor<2x4x4x6x7xf32> +// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> // CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> // CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> -// CHECK-SAME: -> tensor<2x4x4x6x7xf32> -// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ({ // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): // CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], 
%[[ARG4]] : tensor // CHECK: mhlo.return %[[SUM2]] : tensor // CHECK: }) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> // CHECK-SAME: -> tensor<2x4x4x6x7xf32> // CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] // CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32> @@ -5158,13 +5158,13 @@ func.func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4 // CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> // CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> // CHECK-SAME: -> tensor<10x25x33x64xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<1> // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: mhlo.return %[[SUM]] : tensor // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<1> // CHECK-SAME: -> tensor<10x24x32x64xf32> // CHECK: return %[[RESULT]] : tensor<10x24x32x64xf32> func.func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { @@ -5190,13 +5190,13 @@ func.func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor< // CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: -> tensor<10x8x25x33x64xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<1> // CHECK: ^bb0(%[[ARG1:.*]]: 
tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: mhlo.return %[[SUM]] : tensor // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<1> // CHECK-SAME: -> tensor<10x8x24x32x64xf32> // CHECK: return %[[RESULT]] : tensor<10x8x24x32x64xf32> func.func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { @@ -5215,14 +5215,14 @@ func.func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> te // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({ +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> +// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: mhlo.return %[[SUM1]] : tensor // CHECK: }) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> // CHECK-SAME: -> tensor<2x4x7x9xf32> // CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) @@ -5230,13 +5230,13 @@ func.func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> te // CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> // CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]> // CHECK-SAME: -> tensor<2x14x27x9xf32> -// CHECK: %[[RESULT:.*]] 
= "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> +// CHECK-SAME: window_strides = dense<1> // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): // CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor // CHECK: mhlo.return %[[SUM2]] : tensor // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<1> // CHECK-SAME: -> tensor<2x13x25x9xf32> // CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32> func.func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { @@ -5256,14 +5256,14 @@ func.func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({ +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> +// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: mhlo.return %[[SUM1]] : tensor // CHECK: }) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> // CHECK-SAME: -> tensor<2x8x4x7x9xf32> // CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = 
"mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) @@ -5271,13 +5271,13 @@ func.func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x // CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: interior_padding = dense<[0, 0, 3, 3, 0]> // CHECK-SAME: -> tensor<2x8x14x27x9xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> +// CHECK-SAME: window_strides = dense<1> // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): // CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor // CHECK: mhlo.return %[[SUM2]] : tensor // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<1> // CHECK-SAME: -> tensor<2x8x13x25x9xf32> // CHECK: return %[[RESULT]] : tensor<2x8x13x25x9xf32> func.func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { @@ -5296,14 +5296,14 @@ func.func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor< // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({ +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> +// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: mhlo.return %[[SUM1]] : tensor // CHECK: }) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]> -// 
CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> // CHECK-SAME: -> tensor<2x9x4x7xf32> // CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) @@ -5311,13 +5311,13 @@ func.func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor< // CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]> // CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]> // CHECK-SAME: -> tensor<2x9x14x27xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> +// CHECK-SAME: window_strides = dense<1> // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): // CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor // CHECK: mhlo.return %[[SUM2]] : tensor // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<1> // CHECK-SAME: -> tensor<2x9x13x25xf32> // CHECK: return %[[RESULT]] : tensor<2x9x13x25xf32> func.func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { @@ -5337,14 +5337,14 @@ func.func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({ +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> +// CHECK-SAME: window_strides = 
dense<[1, 1, 1, 4, 4]> // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: mhlo.return %[[SUM1]] : tensor // CHECK: }) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> // CHECK-SAME: -> tensor<2x9x8x4x7xf32> // CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) @@ -5352,13 +5352,13 @@ func.func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13 // CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]> // CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]> // CHECK-SAME: -> tensor<2x9x8x14x27xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) +// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> +// CHECK-SAME: window_strides = dense<1> : tensor<5xi64> // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): // CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor // CHECK: mhlo.return %[[SUM2]] : tensor // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<1> : tensor<5xi64> // CHECK-SAME: -> tensor<2x9x8x13x25xf32> // CHECK: return %[[RESULT]] : tensor<2x9x8x13x25xf32> func.func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { @@ -5387,13 +5387,13 @@ func.func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor< // CHECK-SAME: -> tensor<10x25x33x64xbf16> // CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = mhlo.convert %[[REDUCE_WINDOW_INPUT]] : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32> // CHECK: %[[ZERO_F32:.*]] = 
mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ({ +// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) +// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<1> // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: mhlo.return %[[SUM]] : tensor // CHECK: }) -// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<1> // CHECK-SAME: -> tensor<10x24x32x64xf32> // CHECK: %[[RESULT_CONVERTED:.*]] = mhlo.convert %[[RESULT]] : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16> // CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16> @@ -5422,8 +5422,8 @@ func.func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { // CHECK-LABEL: inplace_update_one func.func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> { // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> + // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] // CHECK-DAG: [[UPDATE:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE2]], 
[[RESHAPE1]], [[CST]] %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> @@ -5437,12 +5437,12 @@ func.func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %a // CHECK-LABEL: inplace_update_three func.func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> { // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} - // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} - // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> + // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> + // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<3> : tensor<1xi64>, 
start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> + // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> + // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> + // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] // CHECK-DAG: [[RESHAPE2:%.+]] = mhlo.reshape [[SLICE2]] // CHECK-DAG: [[RESHAPE3:%.+]] = mhlo.reshape [[SLICE3]] @@ -5459,9 +5459,9 @@ func.func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf3 // CHECK-LABEL: xla_dynamic_update_slice func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> { - // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor - // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // 
CHECK: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] : (tensor<1xi32>) -> tensor // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]] : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> // CHECK: return [[DUS]] @@ -5473,7 +5473,7 @@ func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf // CHECK-LABEL: xla_dynamic_update_slice2 func.func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> { - // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<1xi32>) -> tensor<1xi32> // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]] : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> // CHECK: return [[DUS]] @@ -5512,11 +5512,11 @@ func.func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({ + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions 
= dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> // CHECK: return [[CONVERT_REDUCE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor @@ -5532,12 +5532,12 @@ func.func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({ + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) <{edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32> // CHECK: return [[CONVERT_REDUCE]] %0 = "tf.Const"() 
{_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor @@ -5551,16 +5551,16 @@ func.func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK-SAME: [[X:%.*]]: tensor<4xf32> func.func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({ + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> - // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> // CHECK: return [[REVERSE_BACK]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor %1 = 
"tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> @@ -5573,17 +5573,17 @@ func.func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK-SAME: [[X:%.*]]: tensor<4xf32> func.func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({ + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) <{edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert 
[[PAD]] : tensor<4xf32> - // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> // CHECK: return [[REVERSE_BACK]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> @@ -5619,7 +5619,7 @@ func.func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor) -> tensor<4xf32> { // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ({ + // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) // CHECK: mhlo.mul %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> @@ -5857,7 +5857,7 @@ func.func @xla_conv_v2(%lhs: tensor<8x4x16x16x16xf32>, %rhs: tensor<4x3x3x16x16x // CHECK-LABEL: @xladot_matmul( // CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> func.func @xladot_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> { - // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) { + // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-NOT: lhs_batching_dimensions = // CHECK-NOT: rhs_batching_dimensions = @@ -5877,7 +5877,7 @@ func.func @xladot_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> te // CHECK-LABEL: @xladotv2_matmul( // CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> func.func @xladotv2_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> 
tensor<64x16xi32> { - // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) { + // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-NOT: lhs_batching_dimensions = // CHECK-NOT: rhs_batching_dimensions = @@ -5916,7 +5916,7 @@ func.func @xla_dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi3 // CHECK-LABEL: xla_dynamic_slice_i32_consts func.func @xla_dynamic_slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> + // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> @@ -5954,7 +5954,7 @@ func.func @xla_dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tenso // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) %0 = 
"tf.XlaDynamicSlice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> @@ -5996,11 +5996,11 @@ func.func @test_xla_reduce_window(%arg0: tensor<7xf32>, %arg1: tensor) -> t %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> %cst_2 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> %cst_3 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[REDUCE:.*]] = "mhlo.reduce_window"(%arg0, %arg1) ({ + // CHECK: %[[REDUCE:.*]] = "mhlo.reduce_window"(%arg0, %arg1) <{base_dilations = dense<3> : tensor<1xi64>, padding = dense<0> : tensor<1x2xi64>, window_dilations = dense<4> : tensor<1xi64>, window_dimensions = dense<1> : tensor<1xi64>, window_strides = dense<2> : tensor<1xi64>}> ({ // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer3(%[[ARG0]], %[[ARG1]]){{.*}} // CHECK-NEXT: mhlo.return %[[SUM]] : tensor - // CHECK-NEXT: }) {base_dilations = dense<3> : tensor<1xi64>, padding = dense<0> : tensor<1x2xi64>, window_dilations = dense<4> : tensor<1xi64>, window_dimensions = dense<1> : tensor<1xi64>, window_strides = dense<2> : tensor<1xi64>} : (tensor<7xf32>, tensor) -> tensor<10xf32> + // CHECK-NEXT: }) : (tensor<7xf32>, tensor) -> tensor<10xf32> // CHECK-NEXT: return %[[REDUCE]] %0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) {computation = @sum_reducer3} : (tensor<7xf32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> @@ -6020,11 +6020,11 @@ func.func private @sum_reducer3(%arg0: tensor, %arg1: tensor) -> tenso // CHECK-LABEL: @xlasort_int // CHECK-SAME: %[[INPUT:.*]]: tensor<16xi32> func.func @xlasort_int(%input: tensor<16xi32>) -> (tensor<16xi32>) { - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) ({ + // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = 
-1 : i64, is_stable = false}> ({ // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], NOTYPE // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) {dimension = -1 : i64, is_stable = false} : (tensor<16xi32>) -> tensor<16xi32> + // CHECK-NEXT: }) : (tensor<16xi32>) -> tensor<16xi32> // CHECK-NEXT: return %[[SORT]] %output = "tf.XlaSort"(%input) : (tensor<16xi32>) -> (tensor<16xi32>) func.return %output : tensor<16xi32> @@ -6035,11 +6035,11 @@ func.func @xlasort_int(%input: tensor<16xi32>) -> (tensor<16xi32>) { // CHECK-LABEL: @xlasort_float // CHECK-SAME: %[[INPUT:.*]]: tensor<8xf64> func.func @xlasort_float(%input: tensor<8xf64>) -> (tensor<8xf64>) { - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) ({ + // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = -1 : i64, is_stable = false}> ({ // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], TOTALORDER // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) {dimension = -1 : i64, is_stable = false} : (tensor<8xf64>) -> tensor<8xf64> + // CHECK-NEXT: }) : (tensor<8xf64>) -> tensor<8xf64> // CHECK-NEXT: return %[[SORT]] %output = "tf.XlaSort"(%input) : (tensor<8xf64>) -> (tensor<8xf64>) func.return %output : tensor<8xf64> @@ -6067,7 +6067,7 @@ func.func @xla_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tens %cst = "tf.Const"() {value = dense<[10, 12]> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK-NEXT: %1 = mhlo.constant dense<3> : tensor %cst_0 = "tf.Const"() {value = dense<3> : tensor} : () -> tensor - // CHECK-NEXT: %[[OUTPUT_STATE:.*]], %[[OUTPUT:.*]] = "mhlo.rng_bit_generator"(%[[STATE]]) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) + // CHECK-NEXT: %[[OUTPUT_STATE:.*]], %[[OUTPUT:.*]] = "mhlo.rng_bit_generator"(%[[STATE]]) <{rng_algorithm = #mhlo.rng_algorithm}> : 
(tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) // CHECK-NEXT: return %[[OUTPUT_STATE]], %[[OUTPUT]] : tensor<2xui64>, tensor<10x12xui32> %output_key, %output = "tf.XlaRngBitGenerator"(%cst_0, %arg0, %cst) : (tensor, tensor<2xui64>, tensor<2xi32>) -> (tensor<2xui64>, tensor<10x12xui32>) func.return %output_key, %output : tensor<2xui64>, tensor<10x12xui32> @@ -6123,11 +6123,11 @@ func.func private @sum_reducer2(%arg0: tensor, %arg1: tensor) -> tenso func.func @xla_variadic_sort(%arg0: tensor<2x3x4xui8>) -> tensor<2x3x4xui8> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { // CHECK-NEXT: {{.*}} = mhlo.constant dense<0> : tensor %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) ({ + // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = 0 : i64, is_stable = false}> ({ // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) // CHECK-NEXT: %[[CMP:.*]] = func.call @compare_lt(%[[LHS]], %[[RHS]]) : (tensor, tensor) -> tensor // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) {dimension = 0 : i64, is_stable = false} : (tensor<2x3x4xui8>) -> tensor<2x3x4xui8> + // CHECK-NEXT: }) : (tensor<2x3x4xui8>) -> tensor<2x3x4xui8> // CHECK-NEXT: return %[[SORT]] %0 = "tf.XlaVariadicSort"(%arg0, %cst) {_XlaHasReferenceVars = false, comparator = @compare_lt, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", is_stable = false} : (tensor<2x3x4xui8>, tensor) -> tensor<2x3x4xui8> func.return %0 : tensor<2x3x4xui8> @@ -6172,7 +6172,7 @@ func.func @test_xla_select_and_scatter(%arg0: tensor<4x5x1x1xbf16>, %arg1: tenso %cst = "tf.Const"() {value = dense<0> : tensor<4x2xi32>} : () -> tensor<4x2xi32> %cst_0 = "tf.Const"() {value = dense<[2, 2, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %cst_1 = "tf.Const"() {value = dense<[2, 3, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> - // CHECK: %[[SELECT_AND_SCATTER:.*]] = 
"mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + // CHECK: %[[SELECT_AND_SCATTER:.*]] = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[2, 3, 1, 1]> : tensor<4xi64>, window_strides = dense<[2, 2, 1, 1]> : tensor<4xi64>}> ({ // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) // CHECK-NEXT: %[[RES:.*]] = func.call @ge_select(%[[ARG0]], %[[ARG1]]){{.*}} // CHECK-NEXT: mhlo.return %[[RES]] : tensor @@ -6180,7 +6180,7 @@ func.func @test_xla_select_and_scatter(%arg0: tensor<4x5x1x1xbf16>, %arg1: tenso // CHECK-NEXT: ^{{.*}}(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor) // CHECK-NEXT: %[[RES:.*]] = func.call @add_scatter(%[[ARG2]], %[[ARG3]]){{.*}} // CHECK-NEXT: mhlo.return %[[RES]] : tensor - // CHECK-NEXT: }) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[2, 3, 1, 1]> : tensor<4xi64>, window_strides = dense<[2, 2, 1, 1]> : tensor<4xi64>} : (tensor<4x5x1x1xbf16>, tensor<2x2x1x1xbf16>, tensor) -> tensor + // CHECK-NEXT: }) : (tensor<4x5x1x1xbf16>, tensor<2x2x1x1xbf16>, tensor) -> tensor // CHECK-NEXT: return %[[SELECT_AND_SCATTER]] %0 = "tf.XlaSelectAndScatter"(%arg0, %cst_1, %cst_0, %cst, %arg1, %arg2) {scatter = @add_scatter, select = @ge_select} : (tensor<4x5x1x1xbf16>, tensor<4xi32>, tensor<4xi32>, tensor<4x2xi32>, tensor<2x2x1x1xbf16>, tensor) -> tensor func.return %0 : tensor diff --git a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir index e6623350380fcb..fa0170a2b6b7a9 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir @@ -6,7 +6,7 @@ func.func @allowsMHLO() -> (tensor<8x64x32x4xcomplex> {mhlo.sharding = ""}) { %0 = mhlo.constant dense<(1.000000e+00,-1.000000e+00)> : tensor<128x32x4xcomplex> %1 = mhlo.constant dense<(1.000000e+00,1.000000e+00)> : 
tensor<8x64x128xcomplex> - %2 = "mhlo.einsum"(%1, %0) {einsum_config = "abc,cde->abde"} : (tensor<8x64x128xcomplex>, tensor<128x32x4xcomplex>) -> tensor<8x64x32x4xcomplex> + %2 = "mhlo.einsum"(%1, %0) <{einsum_config = "abc,cde->abde"}> : (tensor<8x64x128xcomplex>, tensor<128x32x4xcomplex>) -> tensor<8x64x32x4xcomplex> return %2 : tensor<8x64x32x4xcomplex> } @@ -53,7 +53,7 @@ func.func @nonstatic_shape_mhlo() -> tensor attributes {tf.entry_function %1 = mhlo.convert %0 : (tensor) -> tensor %2 = mhlo.reshape %1 : (tensor) -> tensor<1xi64> // expected-error @+1 {{Node `mhlo.dynamic_iota` must have compile-time constant}} - %3 = "mhlo.dynamic_iota"(%2) {iota_dimension = 0 : i64} : (tensor<1xi64>) -> tensor + %3 = "mhlo.dynamic_iota"(%2) <{iota_dimension = 0 : i64}> : (tensor<1xi64>) -> tensor %4 = mhlo.multiply %3, %3 : tensor return %4 : tensor } diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 21cdf1203a3554..31b6aa272faf1d 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -472,9 +472,7 @@ tf_proto_library( cc_library( name = "passes", - visibility = [ - "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. - ], + visibility = ["//visibility:private"], deps = [ "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", ], @@ -695,6 +693,7 @@ cc_library( hdrs = ["backend_compiler.h"], deps = [ "//tensorflow/core/tfrt/runtime", + "@com_google_absl//absl/status", "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/backend_compiler.h b/tensorflow/compiler/mlir/tfrt/backend_compiler.h index 0e959f04f43554..7167c8ef18e0ea 100644 --- a/tensorflow/compiler/mlir/tfrt/backend_compiler.h +++ b/tensorflow/compiler/mlir/tfrt/backend_compiler.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_BACKEND_COMPILER_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_BACKEND_COMPILER_H_ +#include "absl/status/status.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "tensorflow/core/tfrt/runtime/runtime.h" diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index ce69fa85189423..bfc93b9252ccbf 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -88,9 +88,7 @@ td_library( "tf_mlrt_tpu_ops.td", ], includes = ["."], - visibility = [ - "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. - ], + visibility = ["//visibility:private"], deps = [ ":mlrt_td_files", ":tf_mlrt_td_files", diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td index 6ff38dda69bd85..fcbf2358b3b936 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td @@ -462,7 +462,8 @@ def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", [Pure]> { `tf.IfrtLoadVariableOp` converts the tensor into an IFRT array based on device and sharding configuration specified in `VariableDeviceShardingConfigProto`. - This op returns a scalar string tensor as a key for user to look for the loaded array. + This op returns a scalar string tensor as a key for user to look for the loaded array + and a future containing the restored tensor. 
}]; let arguments = (ins @@ -472,7 +473,8 @@ def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", [Pure]> { ); let results = (outs - TFTensorType:$array_key + TFTensorType:$array_key, + MlrtFutureType: $tensor_future ); } diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td index bb567f32106215..0791423a91c17f 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td @@ -148,5 +148,39 @@ def TFTPUCompileAndExecuteOp : TensorflowMlrt_Op<"tf_tpu_compile_and_execute", [ }]; } +def TFIfrtLoadVariableOp: TensorflowMlrt_Op<"tf_ifrt_load_variable", [Pure]> { + let summary = "Loads a variable tensor as an IFRT array for mlrt"; + + let description = [{ + This is the MLRT version of tf.IfrtLoadVariableOp. + + This op loads a variable tensor as an IFRT array and binds it with the specified name. + + This op is an replacement of `tf.ReadVariableOp` in the case that a constant + variable tensor is an input to the tpu program invoked by `tf.IfrtCall`. + + After a `tf.ReadVariableOp` is lowered into `tf.IfrtLoadVariableOp`, the `tf.IfrtCall` kernel + will bind the loaded IFRT array by name with the tpu program's input. + + `tf.IfrtLoadVariableOp` converts the tensor into an IFRT array based on device and sharding + configuration specified in `VariableDeviceShardingConfigProto`. + + This op returns a scalar string tensor as a key for user to look for the loaded array + and a future containing the restored tensor. 
+ }]; + + let arguments = (ins + TF_Tensor:$variable, + StrAttr:$device_sharding_config_proto_text, + StrAttr:$name + ); + + let results = (outs + TF_Tensor:$array_key, + MlrtFutureType: $tensor_future + ); +} + + #endif diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir index ba644948c6b06d..dec4b733d25b19 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir @@ -6,7 +6,7 @@ // // CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { // CHECK-NEXT: [[HANDLE2:%.*]] = "tf.VarHandleOp" -// CHECK-NEXT: [[KEY:%.*]] = "tf.IfrtLoadVariable"([[HANDLE2]]) +// CHECK-NEXT: [[KEY:%.*]], [[FUTURE:%.*]] = "tf.IfrtLoadVariable"([[HANDLE2]]) // CHECK-SAME: device_sharding_config_proto_text = "sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } device_ids: 0 device_ids: 1 " // CHECK-SAME: name = "__y" // CHECK-NEXT: [[RES:%.*]] = "tf.IfrtCall"([[KEY]], %arg0) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [0 : i32]}> @@ -27,9 +27,9 @@ module { // // CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { // CHECK: "tf.VarHandleOp" -// CHECK-NEXT: [[VARIABLE:%.*]] = "tf.ReadVariableOp" -// CHECK-NEXT: [[KEY:%.*]] = "tf.IfrtLoadVariable" -// CHECK-NEXT: "tf.MatMul"(%arg0, [[VARIABLE]]) +// CHECK-NOT: [[VARIABLE:%.*]] = "tf.ReadVariableOp" +// CHECK-NEXT: [[KEY:%.*]], [[FUTURE:%.*]] = "tf.IfrtLoadVariable" +// CHECK-NEXT: "tf.MatMul"(%arg0, [[FUTURE]]) // CHECK-NEXT: [[RES:%.*]] = "tf.IfrtCall"(%arg0, [[KEY]]) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> // CHECK-NEXT: return [[RES]] : tensor<1x1xf32> // @@ -42,3 +42,22 @@ module { return %result : tensor<1x1xf32> } 
} + +// ----- +// Variable tensor is only for host +// +// CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { +// CHECK: "tf.VarHandleOp" +// CHECK-NOT: [[VARIABLE:%.*]] = "tf.ReadVariableOp" +// CHECK-NEXT: [[KEY:%.*]], [[FUTURE:%.*]] = "tf.IfrtLoadVariable" +// CHECK-NEXT: [[RES:%.*]] = "tf.MatMul"(%arg0, [[FUTURE]]) +// CHECK-NEXT: return [[RES]] : tensor<1x1xf32> +// +module { + func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { + %0 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %2 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<3x1xf32> + %3 = "tf.MatMul"(%arg0, %2) : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> + return %3: tensor<1x1xf32> + } +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir new file mode 100644 index 00000000000000..e1ad0aea205007 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir @@ -0,0 +1,20 @@ +// RUN: tf-tfrt-opt -split-input-file -tf-mlrt-rewrite-ifrt-load-variable %s | FileCheck %s + +// Variable is used by both CPU and TPU +// +// CHECK-LABEL: func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> +// CHECK-NEXT: [[HANDLE:%.*]] = "tf.VarHandleOp"() +// CHECK-NEXT: [[ARRAYKEY:%.*]], [[FURTURE:%.*]] = "tf_mlrt.tf_ifrt_load_variable"([[HANDLE]]) +// CHECK-SAME: {device_sharding_config_proto_text = "sharding { }", name = "__y"} : (tensor>>) -> (tensor, !mlrt.future) +// CHECK-NEXT: [[TENSOR:%.*]] = "tf_mlrt.tf_await"([[FURTURE]]) : (!mlrt.future) -> tensor<3x1xf32> +// CHECK-NEXT: "tf.MatMul"(%arg0, [[TENSOR]]) : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> +// CHECK-NEXT: "tf.IfrtCall"(%arg0, [[ARRAYKEY]]) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> {__tpu_compile_metadata_text = "retvals { sharding { } }"} : (tensor<1x3xf32>, tensor) 
-> tensor<1x1xf32> +// CHECK-NEXT: return +// + func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { + %0 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %array_key, %tensor = "tf.IfrtLoadVariable"(%0) <{device_sharding_config_proto_text = "sharding { }", name = "__y"}> : (tensor>>) -> (tensor, tensor<3x1xf32>) + %1 = "tf.MatMul"(%arg0, %tensor) : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> + %2 = "tf.IfrtCall"(%arg0, %array_key) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> {__tpu_compile_metadata_text = "retvals { sharding { } }"} : (tensor<1x3xf32>, tensor) -> tensor<1x1xf32> + return %2 : tensor<1x1xf32> + } diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir index 3cb879dabe97f7..3151daf80ec759 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir @@ -470,7 +470,8 @@ func.func @ifrt_load_variable_test() -> () { // CHECK-NEXT: "tf_mlrt.ifrt_load_variable"([[HANDLE]]) // CHECK-SAME: device_sharding_config_proto_text // CHECK-SAME: name = "__variable" - %1 = "tf.IfrtLoadVariable"(%0) <{device_sharding_config_proto_text = "sharding { } device_ids: 0 device_ids: 1 ", name = "__variable"}> {__op_key = 2: i32, device = "/device:CPU:0"} : (tensor>>) -> (tensor) + %1, %2 = "tf_mlrt.tf_ifrt_load_variable"(%0) {device_sharding_config_proto_text = "sharding { } device_ids: 0 device_ids: 1 ", name = "__variable", __op_key = 2: i32, device = "/device:CPU:0"} : (tensor>>) -> (tensor, !mlrt.future) + // CHECK-NEXT: mlrt.await_all_control // CHECK-NEXT: return func.return } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 6ef5c011d0a11d..305195e744932f 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ 
b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -124,9 +124,15 @@ cc_library( ":ifrt_constants", ":ifrt_types", "//tensorflow/compiler/jit:xla_cpu_jit", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_a_m_inc_gen", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_n_z_inc_gen", + "//tensorflow/compiler/mlir/tensorflow:visitor", "//tensorflow/compiler/mlir/tf2xla/api/v2:legalize_tf", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_base", @@ -160,6 +166,29 @@ cc_library( ], ) +cc_library( + name = "extract_callback", + srcs = ["extract_callback.cc"], + hdrs = ["extract_callback.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_a_m_inc_gen", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_n_z_inc_gen", + "//tensorflow/compiler/mlir/tensorflow:visitor", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + tf_cc_test( name = "tf2hlo_test", srcs = [ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.cc new file mode 100644 index 00000000000000..67f10482d16d15 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.cc @@ -0,0 +1,81 @@ +/* Copyright 2024 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h" + +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/visitor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +absl::StatusOr> ExtractCallbackModule( + mlir::ModuleOp module, absl::string_view callback_key) { + // Find the entry function name first. 
+ mlir::func::FuncOp callback_entry_func; + module.walk([&](mlir::func::FuncOp func) { + if (func.getSymName().str() == callback_key) { + callback_entry_func = func; + return mlir::WalkResult::skip(); + } + return mlir::WalkResult::advance(); + }); + + if (!callback_entry_func) { + return absl::NotFoundError( + absl::StrCat("Callback key ", callback_key, " not found")); + } + + mlir::StatusScopedDiagnosticHandler diag_handler(module->getContext()); + auto entry_function_name = callback_entry_func.getSymName(); + auto submodule = mlir::TF::CreatePrunedModule(module, entry_function_name); + if (mlir::failed(submodule)) { + return diag_handler.ConsumeStatus(); + } + + // Remove the attribute inherited from saved model loading. They impose + // additional constraint on public functions that are not necessary. + submodule->get()->removeAttr("tf_saved_model.semantics"); + submodule->get().walk([&](mlir::func::FuncOp func) { + if (func.getSymName() == entry_function_name) { + func.setPublic(); + } + }); + return std::move(*submodule); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h new file mode 100644 index 00000000000000..a345d1d881a79e --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_EXTRACT_CALLBACK_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_EXTRACT_CALLBACK_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Extracts a module that consists of a public callback function in name of +// `callback_key` and all its reachables. +absl::StatusOr> ExtractCallbackModule( + mlir::ModuleOp module, absl::string_view callback_key); + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_EXTRACT_CALLBACK_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc index 8646be3554a0d9..ebaf2570bba3f4 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc @@ -92,6 +92,8 @@ CompileAndRegisterIfrtPrograms(absl::string_view model_name, model_name, entry_function_name.str(), *std::move(submodule), ifrt_model_context.GetClient(), &ifrt_model_context.GetThreadPool(), &ifrt_model_context.GetLoadedVariableRegistry(), + &ifrt_model_context.GetRestoreTensorRegistry(), + ifrt_model_context.GetDeviceMgr(), ifrt_model_context.GetShapeRepresentationFn()); // Register the Ifrt program to `ServingExecutableRegistry` so that @@ -145,14 +147,14 @@ absl::Status IfrtBackendCompiler::CompileTensorflow( tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_before", module); } - // Run backward compat pass so that we can use 
bridge to do clustering. - auto backward_compat_result = - tensorflow::RunTPUBackwardCompatConversion(module, {}); - if (mlir::failed(backward_compat_result)) { - return diag_handler.Combine( - absl::InternalError("Failed to handle legacy TPU Ops")); + if (tpu_compiler_ != nullptr) { + // Run backward compat pass so that we can use bridge to do clustering. + if (mlir::failed( + tpu_compiler_->RunTPUBackwardCompatConversion(module, {}))) { + return diag_handler.Combine( + absl::InternalError("Failed to handle legacy TPU Ops")); + } } - if (VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_after", module); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h index c227027a48ba17..2407fe7cc3546c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_ - #include "absl/status/status.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" #include "tensorflow/core/tfrt/runtime/runtime.h" namespace tensorflow { @@ -28,11 +28,17 @@ namespace ifrt_serving { // Implements the custom backend compiler for IFRT based serving in TFRT. class IfrtBackendCompiler : public tensorflow::BackendCompiler { public: + explicit IfrtBackendCompiler(TpuCompiler* tpu_compiler = nullptr) + : tpu_compiler_(tpu_compiler) {} + // Rewrites the tensorflow graph in MLIR for IFRT serving. The methods // extracts regions for IFRT execution on accelerator (e.g. TPU). 
absl::Status CompileTensorflow( tensorflow::tfrt_stub::ModelRuntimeContext& model_context, mlir::ModuleOp module) const override; + + private: + TpuCompiler* tpu_compiler_; // Not owned. }; } // namespace ifrt_serving diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc index 90bdbf1f1ce6e8..b3bf510003e797 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc @@ -84,31 +84,60 @@ class SinkVariableAsNamedArrayPass } } + // TODO(b/332906178): collapse the below with the + // CollectVariablesUsedByDevice above or just remove the + // CollectVariablesUsedByDevice. + // // Rewrite ReadVariableOp with IfrtLoadVariableOp llvm::SmallDenseMap read_to_load; - for (auto& [name, variable_config] : variable_config_by_name) { - for (auto& read_variable_op : variable_config.read_variable_op) { - builder.setInsertionPointAfter(read_variable_op); - // TODO(b/319045348): consider use resource alias analysis for this. 
- auto var_handle = GetDefiningOp( - read_variable_op.getResource()); - - if (!var_handle) { - read_variable_op->emitError( - "ReadVariableOp has no defining VarHandleOp."); - return signalPassFailure(); - } - auto load_variable_op = builder.create( - read_variable_op->getLoc(), - mlir::RankedTensorType::get( - {}, builder.getType()), - var_handle.getResult(), - builder.getStringAttr(variable_config.device_sharding_config), - builder.getStringAttr(name)); - read_to_load[read_variable_op] = load_variable_op; - } + mlir::WalkResult walk_result = + module.walk([&](mlir::TF::ReadVariableOp read_variable_op) { + mlir::FailureOr variable_runtime_name = + GetVariableTensorName(read_variable_op); + if (mlir::failed(variable_runtime_name)) { + read_variable_op->emitError() << "Failed to get variable name."; + return mlir::WalkResult::interrupt(); + } + + builder.setInsertionPointAfter(read_variable_op); + // TODO(b/319045348): consider use resource alias analysis for + // this. + auto var_handle = GetDefiningOp( + read_variable_op.getResource()); + + if (!var_handle) { + read_variable_op->emitError( + "ReadVariableOp has no defining VarHandleOp."); + return mlir::WalkResult::interrupt(); + } + + auto iter = variable_config_by_name.find(*variable_runtime_name); + mlir::StringAttr device_sharding_config_attr; + if (iter == variable_config_by_name.end()) { + device_sharding_config_attr = builder.getStringAttr(""); + } else { + device_sharding_config_attr = + builder.getStringAttr(iter->second.device_sharding_config); + } + + std::vector result_types; + result_types.push_back(mlir::RankedTensorType::get( + {}, builder.getType())); + result_types.push_back(read_variable_op.getResult().getType()); + + auto load_variable_op = builder.create( + read_variable_op->getLoc(), result_types, var_handle.getResult(), + device_sharding_config_attr, + builder.getStringAttr(*variable_runtime_name)); + read_to_load[read_variable_op] = load_variable_op; + + return mlir::WalkResult::advance(); + 
}); + + if (walk_result.wasInterrupted()) { + return signalPassFailure(); } // Rewrite ifrt call: variable tensors are sunk as attribute. @@ -142,7 +171,7 @@ class SinkVariableAsNamedArrayPass variable_arg_indices.push_back(arg_idx); // Variable use the key from IfrtLoadVariable. updated_args.push_back( - read_to_load[arg.read_variable_op].getResult()); + read_to_load[arg.read_variable_op].getArrayKey()); } else { // non variable updated_args.push_back(call->getOperand(arg_idx)); @@ -162,14 +191,16 @@ class SinkVariableAsNamedArrayPass call.erase(); } - // Delete all ReadVariableOps that are not used. - for (auto& [name, variable_config] : variable_config_by_name) { - for (auto& read_variable_op : variable_config.read_variable_op) { - if (read_variable_op.use_empty()) { - read_variable_op.erase(); - } + // Remove all ReadVariableOp after replacing the CPU usage of + // ReadVariableOp. + module.walk([&](mlir::TF::ReadVariableOp read_variable_op) { + if (!read_variable_op->use_empty()) { + // Replace CPU use of ReadVariableOp + read_variable_op.replaceAllUsesWith( + read_to_load[read_variable_op].getTensorFuture()); } - } + read_variable_op.erase(); + }); } private: diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc index 9737c681d28aa8..1a76571e25bc03 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc @@ -73,6 +73,7 @@ void AddClusterToIfrtRuntimeOpsPassPipeline(OpPassManager& pm, pm.addNestedPass(CreateTfIdentityPropagationPass()); pm.addNestedPass(CreateTfRestoreSplittingPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(CreateTfRestorePruningPass()); pm.addNestedPass(CreateTfRestoreMergingPass()); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD index 7d28571db5030a..d7fafb49ee6cdd 
100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD @@ -99,6 +99,7 @@ cc_library( ":async_while", ":fuse_mlrt_ops", ":parallelization", + ":rewrite_ifrt_load_variable", ":tf_to_mlrt", ":while_to_map_fn", "//tensorflow/compiler/mlir/tfrt:tfrt_pipeline_options", @@ -228,3 +229,24 @@ cc_library( "@llvm-project//mlir:TransformUtils", ], ) + +cc_library( + name = "rewrite_ifrt_load_variable", + srcs = ["rewrite_ifrt_load_variable.cc"], + hdrs = ["rewrite_ifrt_load_variable.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/ir/host_runtime:tensorflow_tfrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc index b288ccde63c2f8..af932ff5011895 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc @@ -51,7 +51,7 @@ limitations under the License. 
namespace tensorflow { namespace mlrt_compiler { -StatusOr ConvertTfMlirToBytecode( +absl::StatusOr ConvertTfMlirToBytecode( const TfrtCompileOptions& options, tfrt_stub::FallbackState& fallback_state, mlir::ModuleOp module, tfrt_stub::ModelRuntimeContext& model_context, mlir::OwningOpRef* module_with_op_keys, @@ -132,13 +132,13 @@ StatusOr ConvertTfMlirToBytecode( auto statusor = mlrt::EmitExecutable(registry, module); if (!statusor.ok()) return statusor.status(); bytecode_buffer = std::move(*statusor); - return OkStatus(); + return absl::OkStatus(); }, model_context, &fallback_state, added_xla_function_names)); return bytecode_buffer; } -StatusOr ConvertTfMlirWithOpKeysToBytecode( +absl::StatusOr ConvertTfMlirWithOpKeysToBytecode( const TfrtCompileOptions& options, const tfrt_stub::FallbackState& fallback_state, mlir::ModuleOp module_with_op_keys, diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.cc index ac9606d4ee7f6c..eaa53c838e3796 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/async_while.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" @@ -37,6 +38,7 @@ void RegisterMlrtPasses() { mlir::registerPass([]() { return CreateAsyncWhilePass(); }); mlir::registerPass([]() { return CreateParallelizationPass(); }); mlir::registerPass([]() { return CreateWhileToMapFnPass(); }); + mlir::registerPass([]() { return CreateRewriteIfrtLoadVariablePass(); }); mlir::registerPass( []() { return CreateTfToMlrtPreParallelizationConversionPass({}); }); mlir::registerPass([]() { return CreateTfToMlrtConversionPass({}); }); @@ -50,6 +52,8 @@ void CreateTfToMlrtPipeline(mlir::OpPassManager &pm, pm.addPass( mlrt_compiler::CreateTfToMlrtPreParallelizationConversionPass(options)); + pm.addPass(mlrt_compiler::CreateRewriteIfrtLoadVariablePass()); + if (options.enable_while_parallel_iterations) { pm.addPass(mlrt_compiler::CreateAsyncWhilePass()); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc new file mode 100644 index 00000000000000..368a91ac54f955 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc @@ -0,0 +1,105 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.h" + +#include +#include + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h" + +namespace tensorflow { +namespace mlrt_compiler { +namespace { + +class RewriteIfrtLoadVariablePass + : public mlir::PassWrapper> { + public: + RewriteIfrtLoadVariablePass() = default; + RewriteIfrtLoadVariablePass &operator=(const RewriteIfrtLoadVariablePass &) = + delete; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RewriteIfrtLoadVariablePass) + + private: + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + llvm::StringRef getArgument() const final { + return "tf-mlrt-rewrite-ifrt-load-variable"; + } + + llvm::StringRef getDescription() 
const final { + return "Convert tf.IfrtLoadVariable to tf_mlrt.TFIfrtLoadVariable"; + } + + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + mlir::OpBuilder builder(module); + + module->walk([&](mlir::TF::IfrtLoadVariableOp load_variable_op) { + builder.setInsertionPoint(load_variable_op); + + std::vector result_types; + result_types.push_back(load_variable_op.getArrayKey().getType()); + result_types.push_back(builder.getType()); + auto mlrt_load_variable_op = + builder.create( + load_variable_op->getLoc(), result_types, + load_variable_op->getOperands(), load_variable_op->getAttrs()); + for (auto user : load_variable_op.getTensorFuture().getUsers()) { + builder.setInsertionPoint(user); + auto await_op = builder.create( + user->getLoc(), load_variable_op.getTensorFuture().getType(), + mlrt_load_variable_op.getTensorFuture()); + user->replaceUsesOfWith(load_variable_op.getTensorFuture(), + await_op.getResult()); + } + + for (auto user : load_variable_op.getArrayKey().getUsers()) { + user->replaceUsesOfWith(load_variable_op.getArrayKey(), + mlrt_load_variable_op.getArrayKey()); + } + + load_variable_op->erase(); + }); + } +}; + +} // namespace + +std::unique_ptr> +CreateRewriteIfrtLoadVariablePass() { + return std::make_unique(); +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.h new file mode 100644 index 00000000000000..1423011b05b01c --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_REWRITE_IFRT_LOAD_VARIABLE_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_REWRITE_IFRT_LOAD_VARIABLE_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +// Creates a pass that converts tf.IfrtLoadVariableOp to +// tf_mlrt.TFIfrtLoadVariableOp and inserts tf_mlrt.Await on the returned future +// from tf_mlrt.TFIfrtLoadVariableOp if it is used by CPU ops. +std::unique_ptr> +CreateRewriteIfrtLoadVariablePass(); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_REWRITE_IFRT_LOAD_VARIABLE_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index 37ddf0b1bf076d..350c424636b2f8 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -16,7 +16,6 @@ limitations under the License. #include -#include #include #include #include @@ -24,16 +23,20 @@ limitations under the License. 
#include #include "google/protobuf/text_format.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinDialect.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h.inc" @@ -326,17 +329,24 @@ class GetResourceOpConversion final } }; -// Convert tf.IfrtLoadVariableOp to tf_mlrt.IfrtLoadVariableOp -class IfrtLoadVariableOpConversion - : public mlir::OpConversionPattern { +// Convert tf_mlrt.TFIfrtLoadVariableOp to tf_mlrt.IfrtLoadVariableOp +class TFIfrtLoadVariableOpConversion + : public mlir::OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + TFIfrtLoadVariableOpConversion(mlir::MLIRContext *context, + mlir::TypeConverter *type_converter) + : mlir::OpConversionPattern(context), + type_converter_(*type_converter) {} mlir::LogicalResult matchAndRewrite( - mlir::TF::IfrtLoadVariableOp op, OpAdaptor adaptor, + tf_mlrt::TFIfrtLoadVariableOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - llvm::SmallVector result_types( - op->getNumResults(), rewriter.getType()); + llvm::SmallVector result_types; + for (auto 
type : op->getResultTypes()) { + if (failed(type_converter_.convertType(type, result_types))) + return mlir::failure(); + } + auto new_op = rewriter.create( op.getLoc(), result_types, adaptor.getOperands()[0], op.getDeviceShardingConfigProtoTextAttr(), op.getNameAttr()); @@ -344,6 +354,9 @@ class IfrtLoadVariableOpConversion return mlir::success(); } + + private: + mlir::TypeConverter &type_converter_; }; // Convert tf.IfrtRestoreVariableOp to tf_mlrt.IfrtRestoreVariableOp @@ -523,7 +536,7 @@ class ExecuteOpConversion final : public mlir::ConversionPattern { node_def.device(), op->getNumOperands(), [&](tensorflow::AttrValueMap *attr_value_map) { *attr_value_map = node_def.attr(); - return OkStatus(); + return absl::OkStatus(); }, fallback_state_.device_manager(), fallback_state_.process_function_library_runtime()); @@ -1187,6 +1200,7 @@ class TfToMlrtConversionPass target.addIllegalDialect(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -1223,16 +1237,16 @@ class TfToMlrtConversionPass // Order the list of added ops alphabetically. patterns.add(&context, &type_converter_, &symbol_table); patterns.add(&context); + SetResourceOpConversion, IfrtRestoreVariableOpConversion, + TFAwaitOpConversion, TFPromiseOpConversion>(&context); patterns.add(type_converter_, &context); patterns.add(&context, &symbol_table, &type_converter_, &execute_op_registry_, &op_kernel_cache_, &fallback_state_); - patterns.add, + patterns.add, TFCallOpConversion, TFCallOpConversion>(&context, &type_converter_); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc index a1f9d401f5c485..5ab6678da892a4 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc @@ -34,7 +34,7 @@ bool UseFallback(mlir::Operation *op) { // TF kernels so that we don't need to check every op here. 
return !llvm::isa< mlir::TF::_TfrtSetResourceOp, mlir::TF::_TfrtGetResourceOp, - mlir::TF::BatchFunctionOp, mlir::TF::CaseOp, mlir::TF::IfrtLoadVariableOp, + mlir::TF::BatchFunctionOp, mlir::TF::CaseOp, mlir::TF::IfrtRestoreVariableOp, mlir::TF::StatefulPartitionedCallOp, mlir::TF::PartitionedCallOp, mlir::TF::LegacyCallOp, mlir::TF::IfOp, mlir::TF::WhileOp, mlir::TF::TPUCompileMlirAndExecuteOp>(op); diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index f61a087e782704..3cf8be9c90cb62 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -80,7 +80,7 @@ namespace { // Exports all XLA functions in the form of XlaLaunch, and their nested // functions. -StatusOr> ExportXlaFunctions( +absl::StatusOr> ExportXlaFunctions( mlir::ModuleOp module, std::vector* added_xla_function_names) { // Find all XLA functions. std::vector xla_functions; @@ -306,7 +306,7 @@ Status ConvertTfMlirToBef(const TfrtCompileOptions& options, absl::InternalError("failed to convert MLIR to BEF.")); bef_buffer->shrink_to_fit(); - return OkStatus(); + return absl::OkStatus(); }, model_context, fallback_state, added_xla_function_names); } @@ -364,7 +364,7 @@ tensorflow::Status AddXlaFunctions( } } - return tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/utils/export.cc b/tensorflow/compiler/mlir/tfrt/utils/export.cc index a26008731405f7..182fdb3008193a 100644 --- a/tensorflow/compiler/mlir/tfrt/utils/export.cc +++ b/tensorflow/compiler/mlir/tfrt/utils/export.cc @@ -37,7 +37,8 @@ namespace tensorflow { absl::Status ExportFunctionDefs( mlir::ModuleOp module, - absl::AnyInvocable callback) { + absl::AnyInvocable callback, + bool export_tf_original_func_name) { tsl::profiler::TraceMe traceme([&]() { return tsl::profiler::TraceMeEncode( "ExportFunctionDefs", @@ -58,7 +59,7 @@ 
absl::Status ExportFunctionDefs( } } tensorflow::GraphExportConfig configs; - configs.export_original_tf_func_name = true; + configs.export_original_tf_func_name = export_tf_original_func_name; for (auto func : module.getOps()) { tensorflow::FunctionDef function_def; diff --git a/tensorflow/compiler/mlir/tfrt/utils/export.h b/tensorflow/compiler/mlir/tfrt/utils/export.h index 7a226974bffcf6..84f0e272d4f828 100644 --- a/tensorflow/compiler/mlir/tfrt/utils/export.h +++ b/tensorflow/compiler/mlir/tfrt/utils/export.h @@ -15,7 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_UTILS_EXPORT_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_UTILS_EXPORT_H_ -#include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" @@ -29,7 +28,8 @@ namespace tensorflow { // be suitable for FunctionDef export. absl::Status ExportFunctionDefs( mlir::ModuleOp module, - absl::AnyInvocable callback); + absl::AnyInvocable callback, + bool export_tf_original_func_name = true); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 86e2e269e4d329..1069f3fd172411 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -129,6 +129,7 @@ tf_cc_binary( "@llvm-project//llvm:TargetParser", "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc index d5c29e90ef7ed1..7c53bc23fda464 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc @@ -24,19 +24,26 @@ #include "absl/status/status.h" #include "absl/strings/string_view.h" 
+#include "llvm/ADT/SmallString.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/CodeGen.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Host.h" #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project #include "mlir/ExecutionEngine/OptUtils.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project @@ -45,7 +52,8 @@ #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace kernel_gen { @@ -149,7 +157,7 @@ Status Run(llvm::StringRef input_file, llvm::StringRef output_file, // Write .a file. 
TF_RETURN_IF_ERROR( WriteStringToFile(Env::Default(), output_file.str(), binary)); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index eb6c06ac54f9c4..277511fed098e0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -144,7 +144,7 @@ Status LowerHloToJITInvocation(mlir::ModuleOp module, if (failed(pm.run(module))) { return absl::InternalError("Lowering HLO to JIT invocation failed."); } - return OkStatus(); + return absl::OkStatus(); } Status LowerHlotoLoops(mlir::ModuleOp module, @@ -236,7 +236,7 @@ Status LowerHlotoLoops(mlir::ModuleOp module, if (failed(pm.run(module))) { return absl::InternalError("Lowering HLO to loops failed."); } - return OkStatus(); + return absl::OkStatus(); } Status LowerLoopsToGPU(mlir::ModuleOp module, bool index_64bit, @@ -305,7 +305,7 @@ Status LowerLoopsToGPU(mlir::ModuleOp module, bool index_64bit, if (failed(pm.run(module))) { return absl::InternalError("Lowering to GPU kernels failed."); } - return OkStatus(); + return absl::OkStatus(); } Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module, @@ -350,7 +350,7 @@ Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module, "Lowering to low-level device IR failed."); } - return OkStatus(); + return absl::OkStatus(); } Status AmendKernelLLVMIRWithStaticKnowledge(mlir::ModuleOp module, @@ -366,7 +366,7 @@ Status AmendKernelLLVMIRWithStaticKnowledge(mlir::ModuleOp module, return failed(pm.run(module)) ? tensorflow::errors::Internal( "Amending LLVMIR with static knowledge failed.") - : OkStatus(); + : absl::OkStatus(); } Status GenerateDeviceCode(mlir::ModuleOp module, @@ -387,7 +387,7 @@ Status GenerateDeviceCode(mlir::ModuleOp module, return failed(pm.run(module)) ? 
tensorflow::errors::Internal("Generating device code failed.") - : OkStatus(); + : absl::OkStatus(); } Status LowerHostSideToFinalForm(mlir::ModuleOp module, bool apply_cl_options) { @@ -402,7 +402,7 @@ Status LowerHostSideToFinalForm(mlir::ModuleOp module, bool apply_cl_options) { return failed(pm.run(module)) ? tensorflow::errors::Internal( "Final lowering of host side failed.") - : OkStatus(); + : absl::OkStatus(); } } // namespace diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index c4abb6420d9b38..d1c3af0b9a6191 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -166,7 +166,7 @@ cc_library( "@local_xla//xla/service/gpu/llvm_gpu_backend", ] + if_cuda_is_configured([ "@local_tsl//tsl/platform:cuda_libdevice_path", - "@local_xla//xla/stream_executor/gpu:asm_compiler", + "@local_xla//xla/stream_executor/cuda:cuda_asm_compiler", ]) + if_rocm_is_configured([ "@local_xla//xla/stream_executor/gpu:asm_compiler", "//tensorflow/core/platform:rocm_rocdl_path", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index 345eccfd12c5f0..58f3c195e900d1 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -39,7 +39,7 @@ limitations under the License. 
#include "tsl/platform/cuda_libdevice_path.h" #if GOOGLE_CUDA -#include "xla/stream_executor/gpu/asm_compiler.h" +#include "xla/stream_executor/cuda/cuda_asm_compiler.h" #elif TENSORFLOW_USE_ROCM #include "xla/stream_executor/gpu/asm_compiler.h" #include "tensorflow/core/platform/rocm_rocdl_path.h" diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 8847aa1a811455..6fb2158b00e3a5 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -281,7 +281,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, *test_node_def = test_def; } - return OkStatus(); + return absl::OkStatus(); } // Test fixture. The fixture manages the random number generator and its seed, @@ -1386,7 +1386,7 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i)))); } } - return OkStatus(); + return absl::OkStatus(); } template @@ -1400,7 +1400,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { Str(Ty(i)), ". 
x = ", x.DebugString(), "y = ", y.DebugString())); } } - return OkStatus(); + return absl::OkStatus(); } Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) { @@ -1414,7 +1414,7 @@ Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) { "y = ", y.DebugString())); } } - return OkStatus(); + return absl::OkStatus(); } // Tests if "x" and "y" are tensors of the same type, same shape, and with diff --git a/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc b/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc index 07ed120620e8be..5cc05693996934 100644 --- a/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc +++ b/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc @@ -32,7 +32,7 @@ REGISTER_OP("TestStaticTf") .Output("output: float") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); - return OkStatus(); + return absl::OkStatus(); }); class TestStaticTfOp : public OpKernel { @@ -69,7 +69,7 @@ REGISTER_OP("TestStaticMultipleOutputTf") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); c->set_output(1, c->input(0)); - return OkStatus(); + return absl::OkStatus(); }); class TestStaticMultipleOutputTfOp : public OpKernel { @@ -117,7 +117,7 @@ REGISTER_OP("TestDynamicTf") .Output("output: float") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); - return OkStatus(); + return absl::OkStatus(); }); // Same as TestStaticTfOp, but only copies up to `max_size` attribute. @@ -183,7 +183,7 @@ REGISTER_OP("DynamicMultidim") .Output("output: float") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->UnknownShapeOfRank(5)); - return OkStatus(); + return absl::OkStatus(); }); // Just fill in the data with ones for a given shape. 
@@ -245,7 +245,7 @@ REGISTER_OP("DynamicUnranked") .Output("output: float") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); }); REGISTER_XLA_OP(Name("DynamicUnranked").Device(DEVICE_GPU_XLA_JIT), @@ -258,7 +258,7 @@ REGISTER_OP("TestTfMustBeConstant") .Output("output: float") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); - return OkStatus(); + return absl::OkStatus(); }); class TestTfMustBeConstantOp : public OpKernel { @@ -318,7 +318,7 @@ REGISTER_OP("TestDynamicTfWithBound") .Output("output: float") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); - return OkStatus(); + return absl::OkStatus(); }); class TestDynamicTfWithBoundOp : public OpKernel { diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index b7dc020187883e..e00ad8eb92132e 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,7 +40,7 @@ Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, << "Provided xla::Shape must have the same dims as the Tensor shape."; *literal = xla::BorrowingLiteral( static_cast(DMAHelper::base(&host_tensor)), xla_shape); - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr HostTensorToLiteral(const Tensor& host_tensor) { @@ -63,7 +63,7 @@ Status HostTensorToMutableBorrowingLiteral( *literal = xla::MutableBorrowingLiteral( static_cast(DMAHelper::base(host_tensor)), xla_shape); - return OkStatus(); + return absl::OkStatus(); } Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, @@ -83,7 +83,7 @@ Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, *literal = xla::BorrowingLiteral( buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes)); - return OkStatus(); + return absl::OkStatus(); } Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, @@ 
-106,7 +106,7 @@ Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, void* dst_ptr = DMAHelper::base(host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } - return OkStatus(); + return absl::OkStatus(); } Status LiteralToHostTensor(const xla::LiteralSlice& literal, diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index c24654c894b34f..d1a2a68d045bfc 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h" #include "tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h" -// #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/device.h" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 656c02e1214ac6..29e0de5edafbc2 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -646,7 +646,6 @@ cc_library( "//tensorflow/core/kernels/mkl:mkl_matmul_op", "//tensorflow/core/kernels/mkl:mkl_sparse_matrix_matmul_op", "//tensorflow/core/kernels/mkl:mkl_tmp_ops", - "//tensorflow/core/kernels/mkl:mkl_deprecated_ops", ]) + if_cuda_or_rocm([ "//tensorflow/core/kernels:cudnn_rnn_kernels", ]) + if_cuda([ @@ -662,9 +661,7 @@ cc_library( cc_library( name = "dynamic_kernels_impl", - visibility = [ - "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. 
- ], + visibility = ["//visibility:private"], deps = [ "//tensorflow/core/kernels:sobol_op", ], @@ -1993,7 +1990,6 @@ tf_cc_test_mkl( "//tensorflow/core/kernels/mkl:mkl_softmax_op", "//tensorflow/core/kernels/mkl:mkl_transpose_op", "//tensorflow/core/kernels/mkl:mkl_tmp_ops", - "//tensorflow/core/kernels/mkl:mkl_deprecated_ops", ]), ) diff --git a/tensorflow/core/api_def/base_api/api_def_AssignVariableXlaConcatND.pbtxt b/tensorflow/core/api_def/base_api/api_def_AssignVariableXlaConcatND.pbtxt index 646f602af22688..6bd6bcd8d05ad5 100644 --- a/tensorflow/core/api_def/base_api/api_def_AssignVariableXlaConcatND.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_AssignVariableXlaConcatND.pbtxt @@ -5,17 +5,13 @@ op { name: "resource" description: <